diff --git a/.gitattributes b/.gitattributes index a4c2924d8669ac14ddf1c561e38d7d70fea673a3..79977292b4101015a79d9b72cd06a7d0128976d4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -45,3 +45,11 @@ ComfyUI-KJNodes/fonts/FreeMonoBoldOblique.otf filter=lfs diff=lfs merge=lfs -tex ComfyUI-KJNodes/fonts/TTNorms-Black.otf filter=lfs diff=lfs merge=lfs -text ComfyUI-Kolors-MZ/configs/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text ComfyUI-KwaiKolorsWrapper/configs/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text +PuLID_ComfyUI/examples/pulid_wf.jpg filter=lfs diff=lfs merge=lfs -text +rgthree-comfy/docs/rgthree_advanced_metadata.png filter=lfs diff=lfs merge=lfs -text +rgthree-comfy/docs/rgthree_advanced.png filter=lfs diff=lfs merge=lfs -text +rgthree-comfy/docs/rgthree_context_metadata.png filter=lfs diff=lfs merge=lfs -text +rgthree-comfy/docs/rgthree_context.png filter=lfs diff=lfs merge=lfs -text +x-flux-comfyui/assets/image1.png filter=lfs diff=lfs merge=lfs -text +x-flux-comfyui/guide/manager_menu.png filter=lfs diff=lfs merge=lfs -text +x-flux-comfyui/workflows/example.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/PuLID_ComfyUI/LICENSE b/PuLID_ComfyUI/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/PuLID_ComfyUI/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
diff --git a/PuLID_ComfyUI/README.md b/PuLID_ComfyUI/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8458f51afee1919206ffcda7a61d58f4b0daa76d
--- /dev/null
+++ b/PuLID_ComfyUI/README.md
@@ -0,0 +1,37 @@
+# PuLID ComfyUI
+
+[PuLID](https://github.com/ToTheBeginning/PuLID) ComfyUI native implementation.
+
+![basic workflow](examples/pulid_wf.jpg)
+
+## Important updates
+
+- **2024.05.12:** Added attention masking and the Advanced node, which allows fine-tuning of the generation.
+
+## Notes
+
+The code can be considered beta; things may change in the coming days. In the `examples` directory you'll find some basic workflows.
+
+The original implementation makes use of a [4-step Lightning UNet](https://huggingface.co/ByteDance/SDXL-Lightning). I made a few comparisons with the official Gradio demo using the same model in ComfyUI and I can't see any noticeable difference, meaning that this code should be faithful to the original. The Lightning LoRA, however, doesn't work as well.
+
+When testing other models, though, I noticed some quality degradation. You may need to experiment with CFG and various samplers/schedulers (try `sgm_uniform`).
+
+**The quality of the reference image is very important.** Maybe this is because the EVA CLIP encoder picks up more detail. Be sure to use a clean and sharp picture!
+
+**For IPAdapter compatibility you need to update the IPAdapter extension!**
+
+## The 'method' parameter
+
+`method` applies the weights in different ways: `fidelity` stays closer to the reference ID, while `style` leaves more freedom to the checkpoint. Sometimes the difference is minimal. I've also added `neutral`, which doesn't do any normalization; if you use this option with the standard Apply node, be sure to lower the weight. With the Advanced node you can simply increase the `fidelity` value (see the illustrative sketch after the installation notes).
+
+The Advanced node has a `fidelity` slider and a `projection` option. Projection `ortho_v2` with `fidelity: 8` is the same as the `fidelity` method in the standard node. Projection `ortho` with `fidelity: 16` is the same as the `style` method.
+
+**Lower `fidelity` values grant higher resemblance to the reference image.**
+
+## Installation
+
+- The [PuLID pre-trained model](https://huggingface.co/huchenlei/ipadapter_pulid/resolve/main/ip-adapter_pulid_sdxl_fp16.safetensors?download=true) goes in `ComfyUI/models/pulid/` (thanks to [Chenlei Hu](https://github.com/huchenlei) for converting them into IPAdapter format).
+- The EVA CLIP model is EVA02-CLIP-L-14-336; it should be downloaded automatically (it will end up in the Hugging Face cache directory).
+- The `facexlib` dependency needs to be installed; its models are downloaded on first use.
+- Finally, you need InsightFace with [AntelopeV2](https://huggingface.co/MonsterMMORPG/tools/tree/main); the unzipped models should be placed in `ComfyUI/models/insightface/models/antelopev2`.
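+
+As a rough illustration of the `projection` weighting described above, the sketch below shows one generic way an "orthogonal" projection can be applied to the ID attention output before it is blended with the base attention output. This is a simplified, hypothetical example for intuition only; the tensor names (`base_out`, `id_out`) and the exact normalization are assumptions, not the node's actual code.
+
+```python
+import torch
+
+def blend_id_ortho(base_out: torch.Tensor, id_out: torch.Tensor, weight: float) -> torch.Tensor:
+    # base_out: attention output of the base model; id_out: attention output of
+    # the ID branch. Both are assumed to be shaped [batch, tokens, dim].
+    # Remove from the ID branch the component parallel to the base output
+    # (an orthogonal projection over the token dimension), then blend it back in.
+    dot = (base_out * id_out).sum(dim=-2, keepdim=True)
+    norm = (base_out * base_out).sum(dim=-2, keepdim=True).clamp(min=1e-6)
+    parallel = dot / norm * base_out
+    orthogonal = id_out - parallel
+    return base_out + weight * orthogonal
+```
+
+(This is only meant to convey the intuition behind the `ortho`/`ortho_v2` options; the actual node applies its own weighting.)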
+ diff --git a/PuLID_ComfyUI/__init__.py b/PuLID_ComfyUI/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d636bc9a9dc589ad0c5fedc54f0412324aa75eac --- /dev/null +++ b/PuLID_ComfyUI/__init__.py @@ -0,0 +1,3 @@ +from .pulid import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + +__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] diff --git a/PuLID_ComfyUI/encoders.py b/PuLID_ComfyUI/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..941150a56596999ae5d7ac488024831d03cabb78 --- /dev/null +++ b/PuLID_ComfyUI/encoders.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn + +class IDEncoder(nn.Module): + def __init__(self, width=1280, context_dim=2048, num_token=5): + super().__init__() + self.num_token = num_token + self.context_dim = context_dim + h1 = min((context_dim * num_token) // 4, 1024) + h2 = min((context_dim * num_token) // 2, 1024) + self.body = nn.Sequential( + nn.Linear(width, h1), + nn.LayerNorm(h1), + nn.LeakyReLU(), + nn.Linear(h1, h2), + nn.LayerNorm(h2), + nn.LeakyReLU(), + nn.Linear(h2, context_dim * num_token), + ) + + for i in range(5): + setattr( + self, + f'mapping_{i}', + nn.Sequential( + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, context_dim), + ), + ) + + setattr( + self, + f'mapping_patch_{i}', + nn.Sequential( + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, context_dim), + ), + ) + + def forward(self, x, y): + # x shape [N, C] + x = self.body(x) + x = x.reshape(-1, self.num_token, self.context_dim) + + hidden_states = () + for i, emb in enumerate(y): + hidden_state = getattr(self, f'mapping_{i}')(emb[:, :1]) + getattr(self, f'mapping_patch_{i}')( + emb[:, 1:] + ).mean(dim=1, keepdim=True) + hidden_states += (hidden_state,) + hidden_states = torch.cat(hidden_states, dim=1) + + return torch.cat([x, hidden_states], dim=1) diff --git a/PuLID_ComfyUI/eva_clip/__init__.py b/PuLID_ComfyUI/eva_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa2d014bbfe644b1e247758116bbf1b184738fe5 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/__init__.py @@ -0,0 +1,11 @@ +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss +from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .tokenizer import SimpleTokenizer, tokenize +from .transform import image_transform \ No newline at end of file diff --git a/PuLID_ComfyUI/eva_clip/bpe_simple_vocab_16e6.txt.gz b/PuLID_ComfyUI/eva_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/PuLID_ComfyUI/eva_clip/constants.py b/PuLID_ComfyUI/eva_clip/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/PuLID_ComfyUI/eva_clip/eva_vit_model.py b/PuLID_ComfyUI/eva_clip/eva_vit_model.py new file mode 100644 index 0000000000000000000000000000000000000000..51db88cf0c7b5d7a43f2be80bc59abb6c859c4b4 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/eva_vit_model.py @@ -0,0 +1,548 @@ +# -------------------------------------------------------- +# Adapted from https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import math +import os +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +try: + from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +except: + from timm.layers import drop_path, to_2tuple, trunc_normal_ + +from .transformer import PatchDropout +from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast + +if os.getenv('ENV_TYPE') == 'deepspeed': + try: + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint + except: + from torch.utils.checkpoint import checkpoint +else: + from torch.utils.checkpoint import checkpoint + +try: + import xformers + import xformers.ops as xops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + drop=0., + subln=False, + + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.ffn_ln(x) + + x = self.fc2(x) + x = self.drop(x) + return x + +class SwiGLU(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0., + norm_layer=nn.LayerNorm, subln=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(in_features, hidden_features) + + self.act = act_layer() + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + self.w3 = nn.Linear(hidden_features, out_features) + + self.drop = nn.Dropout(drop) + + def forward(self, x): + x1 = self.w1(x) + x2 = self.w2(x) + hidden = self.act(x1) * x2 + x = self.ffn_ln(hidden) + x = self.w3(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.subln = subln + if self.subln: + self.q_proj = nn.Linear(dim, all_head_dim, bias=False) + self.k_proj = nn.Linear(dim, all_head_dim, bias=False) + self.v_proj = nn.Linear(dim, all_head_dim, bias=False) + else: + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += 
window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity() + # self.proj = nn.Linear(all_head_dim, all_head_dim) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + self.rope = rope + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + B, N, C = x.shape + if self.subln: + q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias) + k = F.linear(input=x, weight=self.k_proj.weight, bias=None) + v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias) + + q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C + k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + else: + + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C + q, k, v = qkv[0], qkv[1], qkv[2] + + if self.rope: + # slightly fast impl + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] + ro_k_t = self.rope(k_t) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + + if self.xattn: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + x = xops.memory_efficient_attention( + q, k, v, + p=self.xattn_drop, + scale=self.scale, + ) + x = x.reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0).type_as(attn) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias.type_as(attn) + + if attn_mask is not None: + attn_mask = attn_mask.bool() + attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, 
norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False, + subln=False, naiveswiglu=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, + xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if naiveswiglu: + self.mlp = SwiGLU( + in_features=dim, + hidden_features=mlp_hidden_dim, + subln=subln, + norm_layer=norm_layer, + ) + else: + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + subln=subln, + drop=drop + ) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + self.postnorm = postnorm + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + if self.gamma_1 is None: + if self.postnorm: + x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + if self.postnorm: + x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class EVAVisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0., + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False, + use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False, + pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False): + super().__init__() + + if not XFORMERS_IS_AVAILBLE: + xattn = False + + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + + if rope: + half_head_dim = embed_dim // num_heads // 2 + hw_seq_len = img_size // patch_size + self.rope = 
VisionRotaryEmbeddingFast( + dim=half_head_dim, + pt_seq_len=pt_hw_seq_len, + ft_seq_len=hw_seq_len if intp_freq else None, + # patch_dropout=patch_dropout + ) + else: + self.rope = None + + self.naiveswiglu = naiveswiglu + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, + xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu) + for i in range(depth)]) + self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) + + self.apply(self._init_weights) + self.fix_init_weight() + + if isinstance(self.head, nn.Linear): + trunc_normal_(self.head.weight, std=.02) + self.head.weight.data.mul_(init_scale) + self.head.bias.data.mul_(init_scale) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.grad_checkpointing = grad_checkpointing + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + if self.naiveswiglu: + rescale(layer.mlp.w3.weight.data, layer_id + 1) + else: + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_cast_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False): + + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + if shuffle: + idx = torch.randperm(x.shape[1]) + 1 + zero = torch.LongTensor([0, ]) + idx = torch.cat([zero, idx]) + pos_embed = self.pos_embed[:, idx] + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if shuffle: + x = x + pos_embed + elif 
self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + if os.getenv('RoPE') == '1': + if self.training and not isinstance(self.patch_dropout, nn.Identity): + x, patch_indices_keep = self.patch_dropout(x) + self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep) + else: + self.rope.forward = partial(self.rope.forward, patch_indices_keep=None) + x = self.patch_dropout(x) + else: + x = self.patch_dropout(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + hidden_states = [] + for idx, blk in enumerate(self.blocks): + if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden: + hidden_states.append(x) + if self.grad_checkpointing: + x = checkpoint(blk, x, (rel_pos_bias,)) + else: + x = blk(x, rel_pos_bias=rel_pos_bias) + + if not return_all_features: + x = self.norm(x) + if self.fc_norm is not None: + return self.fc_norm(x.mean(1)), hidden_states + else: + return x[:, 0], hidden_states + return x + + def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False): + if return_all_features: + return self.forward_features(x, return_all_features, return_hidden, shuffle) + x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle) + x = self.head(x) + if return_hidden: + return x, hidden_states + return x diff --git a/PuLID_ComfyUI/eva_clip/factory.py b/PuLID_ComfyUI/eva_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..ced8999997bf374b69f846bc73ea635fe8a6eb63 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/factory.py @@ -0,0 +1,517 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Optional, Tuple, Union, Dict, Any +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + get_cast_dtype +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model +from .transform import image_transform +from .tokenizer import HFTokenizer, tokenize +from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed + + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, "r", encoding="utf8") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files 
""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + config = get_model_config(model_name) + tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + return tokenizer + + +# loading openai CLIP weights when is_openai=True for training +def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]): + if is_openai: + model = torch.jit.load(checkpoint_path, map_location="cpu").eval() + state_dict = model.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + else: + checkpoint = torch.load(checkpoint_path, map_location=map_location) + for mk in model_key.split('|'): + if isinstance(checkpoint, dict) and mk in checkpoint: + state_dict = checkpoint[mk] + break + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + for k in skip_list: + if k in list(state_dict.keys()): + logging.info(f"Removing key {k} from pretrained checkpoint") + del state_dict[k] + + if os.getenv('RoPE') == '1': + for k in list(state_dict.keys()): + if 'freqs_cos' in k or 'freqs_sin' in k: + del state_dict[k] + return state_dict + + + +def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True): + state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'): + state_dict['logit_scale'] = state_dict['text.logit_scale'] + del state_dict['text.logit_scale'] + + # resize_clip_pos_embed for CLIP and open CLIP + if 'visual.positional_embedding' in state_dict: + resize_clip_pos_embed(state_dict, model) + # specified to eva_vit_model + elif 'visual.pos_embed' in state_dict: + resize_evaclip_pos_embed(state_dict, model) + + # resize_clip_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}") + return incompatible_keys + +def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]): + state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) + + for k in list(state_dict.keys()): + if not k.startswith('visual.'): + del state_dict[k] + for k in list(state_dict.keys()): + if k.startswith('visual.'): + new_k = k[7:] + state_dict[new_k] = state_dict[k] + del state_dict[k] + return state_dict + +def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]): + state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) + + for k in list(state_dict.keys()): + if 
k.startswith('visual.'): + del state_dict[k] + return state_dict + +def get_pretrained_tag(pretrained_model): + pretrained_model = pretrained_model.lower() + if "laion" in pretrained_model or "open_clip" in pretrained_model: + return "open_clip" + elif "openai" in pretrained_model: + return "clip" + elif "eva" in pretrained_model and "clip" in pretrained_model: + return "eva_clip" + else: + return "other" + +def load_pretrained_checkpoint( + model, + visual_checkpoint_path, + text_checkpoint_path, + strict=True, + visual_model=None, + text_model=None, + model_key="model|module|state_dict", + skip_list=[]): + visual_tag = get_pretrained_tag(visual_model) + text_tag = get_pretrained_tag(text_model) + + logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}") + visual_incompatible_keys, text_incompatible_keys = None, None + if visual_checkpoint_path: + if visual_tag == "eva_clip" or visual_tag == "open_clip": + visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list) + elif visual_tag == "clip": + visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list) + else: + visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) + + # resize_clip_pos_embed for CLIP and open CLIP + if 'positional_embedding' in visual_state_dict: + resize_visual_pos_embed(visual_state_dict, model) + # specified to EVA model + elif 'pos_embed' in visual_state_dict: + resize_eva_pos_embed(visual_state_dict, model) + + visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict) + logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}") + logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}") + + if text_checkpoint_path: + if text_tag == "eva_clip" or text_tag == "open_clip": + text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list) + elif text_tag == "clip": + text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list) + else: + text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) + + text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict) + + logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}") + logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}") + + return visual_incompatible_keys, text_incompatible_keys + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + pretrained_image: str = '', + pretrained_text: str = '', + pretrained_hf: bool = True, + pretrained_visual_model: str = None, + pretrained_text_model: str = None, + cache_dir: Optional[str] = None, + skip_list: list = [], +): + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logging.info(f'Loading pretrained {model_name} from OpenAI.') + model = load_openai_model( + model_name, + precision=precision, + device=device, + jit=jit, + 
cache_dir=cache_dir, + ) + else: + model_cfg = get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') + else: + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if 'rope' in model_cfg.get('vision_cfg', {}): + if model_cfg['vision_cfg']['rope']: + os.environ['RoPE'] = "1" + else: + os.environ['RoPE'] = "0" + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout + + cast_dtype = get_cast_dtype(precision) + custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg']) + + + if custom_clip: + if 'hf_model_name' in model_cfg.get('text_cfg', {}): + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf + model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + pretrained_cfg = {} + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, + checkpoint_path, + model_key="model|module|state_dict", + strict=False + ) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' + f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) + else: + visual_checkpoint_path = '' + text_checkpoint_path = '' + + if pretrained_image: + pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names + pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image) + if 'timm_model_name' in model_cfg.get('vision_cfg', {}): + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + elif pretrained_image_cfg: + visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained_image): + visual_checkpoint_path = pretrained_image + else: + logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.') + raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.') + + if pretrained_text: + pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names + pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text) + if pretrained_image_cfg: + text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained_text): + text_checkpoint_path = pretrained_text + else: + logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.') + raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.') + + if visual_checkpoint_path: + logging.info(f'Loading pretrained {model_name}.visual weights 
({visual_checkpoint_path}).') + if text_checkpoint_path: + logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).') + + if visual_checkpoint_path or text_checkpoint_path: + load_pretrained_checkpoint( + model, + visual_checkpoint_path, + text_checkpoint_path, + strict=False, + visual_model=pretrained_visual_model, + text_model=pretrained_text_model, + model_key="model|module|state_dict", + skip_list=skip_list + ) + + if "fp16" in precision or "bf16" in precision: + logging.info(f'convert precision to {precision}') + model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16) + + model.to(device=device) + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + + if jit: + model = torch.jit.script(model) + + return model + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + pretrained_image: str = '', + pretrained_text: str = '', + pretrained_hf: bool = True, + pretrained_visual_model: str = None, + pretrained_text_model: str = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + skip_list: list = [], +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_clip=force_custom_clip, + force_patch_dropout=force_patch_dropout, + pretrained_image=pretrained_image, + pretrained_text=pretrained_text, + pretrained_hf=pretrained_hf, + pretrained_visual_model=pretrained_visual_model, + pretrained_text_model=pretrained_text_model, + cache_dir=cache_dir, + skip_list=skip_list, + ) + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std + ) + + return model, preprocess_train, preprocess_val + + +def create_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + pretrained_image: str = '', + pretrained_text: str = '', + pretrained_hf: bool = True, + pretrained_visual_model: str = None, + pretrained_text_model: str = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + skip_list: list = [], +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_clip=force_custom_clip, + force_patch_dropout=force_patch_dropout, + pretrained_image=pretrained_image, + pretrained_text=pretrained_text, + pretrained_hf=pretrained_hf, + pretrained_visual_model=pretrained_visual_model, + pretrained_text_model=pretrained_text_model, + 
cache_dir=cache_dir, + skip_list=skip_list, + ) + + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std + ) + del model + + return preprocess_train, preprocess_val + +def create_model_from_pretrained( + model_name: str, + pretrained: str, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + is_frozen: bool = False, +): + if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained): + raise RuntimeError( + f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.' + f' Use open_clip.list_pretrained() to find one.') + + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_clip=force_custom_clip, + force_patch_dropout=force_patch_dropout, + cache_dir=cache_dir, + ) + + if is_frozen: + for param in model.parameters(): + param.requires_grad = False + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std + ) + + return model, preprocess diff --git a/PuLID_ComfyUI/eva_clip/hf_configs.py b/PuLID_ComfyUI/eva_clip/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c9b704db1879676aed5cef26796303b65fe987 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/hf_configs.py @@ -0,0 +1,57 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, + "bert": { + "config_names": { + 
"context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + } +} diff --git a/PuLID_ComfyUI/eva_clip/hf_model.py b/PuLID_ComfyUI/eva_clip/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b9fd85b4066ba31db2bda5767ed1ce15de479d --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/hf_model.py @@ -0,0 +1,248 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. +""" + +import re + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch import TensorType +try: + import transformers + from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +from .hf_configs import arch_dict + +# utils +def _camel2snake(s): + return re.sub(r'(? TensorType: + # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device) + # attn_mask = (x != self.config.pad_token_id).long() + # out = self.transformer( + # input_ids=x, + # attention_mask=attn_mask, + # encoder_hidden_states = image_embeds, + # encoder_attention_mask = image_atts, + # ) + # pooled_out = self.pooler(out, attn_mask) + + # return self.itm_proj(pooled_out) + + def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None): + if masked_indices is None: + masked_indices = torch.bernoulli(probability_matrix).bool() + + masked_indices[input_ids == self.tokenizer.pad_token_id] = False + masked_indices[input_ids == self.tokenizer.cls_token_id] = False + + if targets is not None: + targets[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices + input_ids[indices_replaced] = self.tokenizer.mask_token_id + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) + input_ids[indices_random] = random_words[indices_random] + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + + if targets is not None: + return input_ids, targets + else: + return input_ids + + def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25): + labels = input_ids.clone() + attn_mask = (input_ids != self.config.pad_token_id).long() + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device) + vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"]) + probability_matrix = torch.full(labels.shape, mlm_probability) + input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, + probability_matrix = probability_matrix) + mlm_output = self.transformer(input_ids, + attention_mask = attn_mask, + encoder_hidden_states = image_embeds, + 
encoder_attention_mask = image_atts, + return_dict = True, + labels = labels, + ) + return mlm_output.loss + # mlm_output = self.transformer(input_ids, + # attention_mask = attn_mask, + # encoder_hidden_states = image_embeds, + # encoder_attention_mask = image_atts, + # return_dict = True, + # ).last_hidden_state + # logits = self.mlm_proj(mlm_output) + + # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size) + # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size) + # labels = labels[:, 1:].contiguous().view(-1) + + # mlm_loss = F.cross_entropy( + # logits, + # labels, + # # label_smoothing=0.1, + # ) + # return mlm_loss + + + def forward(self, x:TensorType) -> TensorType: + attn_mask = (x != self.config.pad_token_id).long() + out = self.transformer(input_ids=x, attention_mask=attn_mask) + pooled_out = self.pooler(out, attn_mask) + + return self.proj(pooled_out) + + def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True): + if not unlocked_layers: # full freezing + for n, p in self.transformer.named_parameters(): + p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + return + + encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer + layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) + print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model") + embeddings = getattr( + self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"]) + modules = [embeddings, *layer_list][:-unlocked_layers] + # freeze layers + for module in modules: + for n, p in module.named_parameters(): + p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.gradient_checkpointing_enable() + + def get_num_layers(self): + encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer + layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) + return len(layer_list) + + def init_parameters(self): + pass diff --git a/PuLID_ComfyUI/eva_clip/loss.py b/PuLID_ComfyUI/eva_clip/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..473f60d98d501067e85ace2dd089b00e249b6d17 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/loss.py @@ -0,0 +1,138 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F + +try: + import torch.distributed.nn + from torch import distributed as dist + has_distributed = True +except ImportError: + has_distributed = False + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from timm.loss import LabelSmoothingCrossEntropy + + +def gather_features( + image_features, + text_features, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False +): + assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
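+ # The two branches below perform the same cross-rank gather, either through
+ # Horovod's allgather or through torch.distributed. When local_loss is False,
+ # the local rank's own tensors are spliced back into the gathered result so
+ # that gradients still reach them, since a plain (non-autograd) all_gather
+ # returns tensors without gradient history.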
+ if use_horovod: + assert hvd is not None, 'Please install horovod' + if gather_with_grad: + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + else: + with torch.no_grad(): + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) + gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) + all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) + # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) + # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) + else: + gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] + gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + +class ClipLoss(nn.Module): + + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + smoothing=0., + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward(self, image_features, text_features, logit_scale=1.): + device = image_features.device + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + # calculated ground-truth and cache if enabled + num_logits = logits_per_image.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + 
num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + if self.label_smoothing_cross_entropy: + total_loss = ( + self.label_smoothing_cross_entropy(logits_per_image, labels) + + self.label_smoothing_cross_entropy(logits_per_text, labels) + ) / 2 + else: + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + acc = None + i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) + t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) + acc = {"i2t": i2t_acc, "t2i": t2i_acc} + return total_loss, acc \ No newline at end of file diff --git a/PuLID_ComfyUI/eva_clip/model.py b/PuLID_ComfyUI/eva_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..90f34b91ece98b3c02ac8f1370d1567df32fa1d4 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/model.py @@ -0,0 +1,439 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union +from functools import partial + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +try: + from .hf_model import HFTextEncoder +except: + HFTextEncoder = None +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .eva_vit_model import EVAVisionTransformer +from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer + +try: + from apex.normalization import FusedLayerNorm +except: + FusedLayerNorm = LayerNorm + print("Nvidia APEX normalization not installed, using PyTorch LayerNorm") + +try: + import xformers.ops as xops +except ImportError: + xops = None + #print("Please 'pip install xformers'") + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. 
# what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + drop_path_rate: Optional[float] = None # drop path rate + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size + qkv_bias: bool = True + fusedLN: bool = False + xattn: bool = False + postnorm: bool = False + rope: bool = False + pt_hw_seq_len: int = 16 # 224/14 + intp_freq: bool = False + naiveswiglu: bool = False + subln: bool = False + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + masked_language_modeling: bool = False + fusedLN: bool = False + xattn: bool = False + attn_mask: bool = True + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
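+ # QuickGELU approximates GELU as x * sigmoid(1.702 * x); choosing it only
+ # matters for checkpoints (e.g. the original OpenAI CLIP weights) that were
+ # trained with that activation.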
+ act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.eva_model_name: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNorm + + visual = EVAVisionTransformer( + img_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + num_classes=embed_dim, + use_mean_pooling=vision_cfg.global_average_pool, #False + init_values=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + embed_dim=vision_cfg.width, + depth=vision_cfg.layers, + num_heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + qkv_bias=vision_cfg.qkv_bias, + drop_path_rate=vision_cfg.drop_path_rate, + norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6), + xattn=vision_cfg.xattn, + rope=vision_cfg.rope, + postnorm=vision_cfg.postnorm, + pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14 + intp_freq= vision_cfg.intp_freq, + naiveswiglu= vision_cfg.naiveswiglu, + subln= vision_cfg.subln + ) + elif vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + embed_dim=embed_dim, + image_size=vision_cfg.image_size + ) + act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + global_average_pool=vision_cfg.global_average_pool, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + tokenizer_name=text_cfg.hf_tokenizer_name, + proj=text_cfg.proj, + pooler_type=text_cfg.pooler_type, + masked_language_modeling=text_cfg.masked_language_modeling + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer, + xattn=text_cfg.xattn, + attn_mask=text_cfg.attn_mask, + ) + return text + +class CLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = 
_build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {'logit_scale'} + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + return image_features, text_features, self.logit_scale.exp() + + +class CustomCLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + itm_task: bool = False, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + @torch.jit.ignore + def no_weight_decay(self): + return {'logit_scale'} + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + return image_features, text_features, 
self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr, None) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + if isinstance(l, nn.Parameter): + l.data = l.data.to(dtype) + + for name in ["text_projection", "proj"]: + if hasattr(l, name) and isinstance(l, nn.Parameter): + attr = getattr(l, name, None) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + 'logit_scale' + )): + k = 'text.' + k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially 
converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model diff --git a/PuLID_ComfyUI/eva_clip/model_configs/EVA01-CLIP-B-16.json b/PuLID_ComfyUI/eva_clip/model_configs/EVA01-CLIP-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..aad2058003962a4ab286bf4e1ae956288af34e62 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/model_configs/EVA01-CLIP-B-16.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16, + "eva_model_name": "eva-clip-b-16", + "ls_init_value": 0.1, + "drop_path_rate": 0.0 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/PuLID_ComfyUI/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json b/PuLID_ComfyUI/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..100279572ff6d1bcca601f0eb526b4d4ff174c7d --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14, + "eva_model_name": "eva-clip-g-14-x", + "drop_path_rate": 0, + "xattn": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/PuLID_ComfyUI/eva_clip/model_configs/EVA01-CLIP-g-14.json b/PuLID_ComfyUI/eva_clip/model_configs/EVA01-CLIP-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..5d338b4e6104241d1f0304ee82400035d5385332 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/model_configs/EVA01-CLIP-g-14.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14, + "eva_model_name": "eva-clip-g-14-x", + "drop_path_rate": 0.4, + "xattn": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-B-16.json b/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..e4a6e723f77033caa341ddf9b5be1787d64ad42c --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-B-16.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "head_width": 64, + "patch_size": 16, + "mlp_ratio": 2.6667, + "eva_model_name": "eva-clip-b-16-X", + "drop_path_rate": 0.0, + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + 
"subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "xattn": true, + "fusedLN": true + } +} \ No newline at end of file diff --git a/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-L-14-336.json b/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..3e1d124e1118911c5ad7b1ce85df195aca363ac4 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-L-14-336.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "drop_path_rate": 0, + "head_width": 64, + "mlp_ratio": 2.6667, + "patch_size": 14, + "eva_model_name": "eva-clip-l-14-336", + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-L-14.json b/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..03b22ad3cfb92f9c843b9ec8d672e57e7a9ba4a2 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-L-14.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "drop_path_rate": 0, + "head_width": 64, + "mlp_ratio": 2.6667, + "patch_size": 14, + "eva_model_name": "eva-clip-l-14", + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json b/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..aa04e2545ac1e015daae2c10133956ce969524f7 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 64, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.571428571428571, + "patch_size": 14, + "eva_model_name": "eva-clip-4b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} diff --git a/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-bigE-14.json b/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-bigE-14.json new file mode 100644 index 0000000000000000000000000000000000000000..747ffccc8bd49dbb6701b58e15843b7fe3754e64 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/model_configs/EVA02-CLIP-bigE-14.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 64, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.571428571428571, + "patch_size": 14, + "eva_model_name": "eva-clip-4b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24, + "xattn": false, + "fusedLN": 
true + } +} \ No newline at end of file diff --git a/PuLID_ComfyUI/eva_clip/modified_resnet.py b/PuLID_ComfyUI/eva_clip/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8d3aeda91ecb394303becbbfccc8acd8cddcd9 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/modified_resnet.py @@ -0,0 +1,181 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/PuLID_ComfyUI/eva_clip/openai.py b/PuLID_ComfyUI/eva_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4e13e876d6a7a3463b457e62c517cb063b1356 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/openai.py @@ -0,0 +1,144 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
+""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + jit: bool = True, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + if precision.startswith('amp') or precision == 'fp32': + model.float() + elif precision == 'bf16': + convert_weights_to_lp(model, dtype=torch.bfloat16) + + return model + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 (typically for CPU) + if precision == 'fp32': + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + model.float() + + # ensure image_size attr available at consistent location for both jit and non-jit + model.visual.image_size = model.input_resolution.item() + return model diff --git a/PuLID_ComfyUI/eva_clip/pretrained.py b/PuLID_ComfyUI/eva_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e55dcf36a0e7dbd4c13b4ca2d7cb460e4c3547 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/pretrained.py @@ -0,0 +1,332 @@ +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Union + +from tqdm import tqdm + +try: + from huggingface_hub import hf_hub_download + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', filename='', mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + 
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), +) + +_EVAB16 = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'), + eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'), + eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), +) + +_EVAL14 = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'), + eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'), + eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_EVAL14_336 = dict( + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'), + eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'), + eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'), + eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'), +) + +_VITH14 = dict( + 
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_EVAg14 = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'), + eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'), + eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'), +) + +_EVAg14_PLUS = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'), + eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'), + eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + +_EVAbigE14 = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), + eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'), + eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'), +) + +_EVAbigE14_PLUS = dict( + eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), + eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), + eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'), + eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'), +) + + +_PRETRAINED = { + # "ViT-B-32": _VITB32, + "OpenaiCLIP-B-32": _VITB32, + "OpenCLIP-B-32": _VITB32, + + # "ViT-B-32-quickgelu": _VITB32_quickgelu, + "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu, + "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu, + + # "ViT-B-16": _VITB16, + "OpenaiCLIP-B-16": _VITB16, + "OpenCLIP-B-16": _VITB16, + + "EVA02-B-16": _EVAB16, + "EVA02-CLIP-B-16": _EVAB16, + + # "ViT-B-16-plus-240": _VITB16_PLUS_240, + "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240, + + # "ViT-L-14": _VITL14, + "OpenaiCLIP-L-14": _VITL14, + "OpenCLIP-L-14": _VITL14, + + "EVA02-L-14": _EVAL14, + "EVA02-CLIP-L-14": _EVAL14, + + # "ViT-L-14-336": _VITL14_336, + "OpenaiCLIP-L-14-336": _VITL14_336, + + "EVA02-CLIP-L-14-336": _EVAL14_336, + + # "ViT-H-14": _VITH14, + # "ViT-g-14": _VITg14, + "OpenCLIP-H-14": _VITH14, + "OpenCLIP-g-14": _VITg14, + + "EVA01-CLIP-g-14": _EVAg14, + "EVA01-CLIP-g-14-plus": _EVAg14_PLUS, + + # "ViT-bigG-14": _VITbigG14, + "OpenCLIP-bigG-14": _VITbigG14, + + "EVA02-CLIP-bigE-14": _EVAbigE14, + "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS, +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in 
_PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 
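+ # For example, 'QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt' splits into
+ # ('QuanSun/EVA-CLIP', 'EVA02_CLIP_L_psz14_s4B.pt'), while
+ # 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K/' splits into
+ # ('laion/CLIP-ViT-H-14-laion2B-s32B-b79K', '') and the default filename is used.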
+ model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/PuLID_ComfyUI/eva_clip/rope.py b/PuLID_ComfyUI/eva_clip/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..69030c35ea7b6b4f298daebbee5717f3fa1254ab --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/rope.py @@ -0,0 +1,137 @@ +from math import pi +import torch +from torch import nn +from einops import rearrange, repeat +import logging + +def broadcat(tensors, dim = -1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim = dim) + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs = None, + freqs_for = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f'unknown modality {freqs_for}') + + if ft_seq_len is None: ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum('..., f -> ... f', t, freqs) + freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) + + freqs_w = torch.einsum('..., f -> ... f', t, freqs) + freqs_w = repeat(freqs_w, '... n -> ... 
(n r)', r = 2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') + + def forward(self, t, start_index = 0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + + return torch.cat((t_left, t, t_right), dim = -1) + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs = None, + freqs_for = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + patch_dropout = 0. + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f'unknown modality {freqs_for}') + + if ft_seq_len is None: ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum('..., f -> ... f', t, freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.patch_dropout = patch_dropout + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') + + def forward(self, t, patch_indices_keep=None): + if patch_indices_keep is not None: + batch = t.size()[0] + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) + freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) + + freqs_cos = freqs_cos[batch_indices, patch_indices_keep] + freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j') + freqs_sin = freqs_sin[batch_indices, patch_indices_keep] + freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j') + + return t * freqs_cos + rotate_half(t) * freqs_sin + + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin \ No newline at end of file diff --git a/PuLID_ComfyUI/eva_clip/timm_model.py b/PuLID_ComfyUI/eva_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b58122c0b84fbda9e51867342823222234e17505 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/timm_model.py @@ -0,0 +1,122 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
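+
+A minimal usage sketch (assumes the timm package and an architecture name such as 'resnet50'
+are available locally; the shapes shown are for that example only):
+
+    tower = TimmModel('resnet50', embed_dim=512, image_size=224, pool='avg', proj='linear')
+    feats = tower(torch.randn(1, 3, 224, 224))  # -> tensor of shape [1, 512]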
+""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + pretrained=False): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + self.trunk = timm.create_model(model_name, pretrained=pretrained) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if pool in ('abs_attn', 'rot_attn'): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == 'abs_attn': + head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, 'projection layer needed if non-attention pooling is used.' 
+ + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/PuLID_ComfyUI/eva_clip/tokenizer.py b/PuLID_ComfyUI/eva_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..41482f82aebbf197f4ee4e6c07c845a0d69dd7d6 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/tokenizer.py @@ -0,0 +1,201 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
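+    For example, the byte for 'a' maps to itself, while the space byte (0x20) maps to 'Ġ' (chr(288)).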
+ """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + 
An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + "HuggingFace tokenizer wrapper" + def __init__(self, tokenizer_name:str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids + return input_ids diff --git a/PuLID_ComfyUI/eva_clip/transform.py b/PuLID_ComfyUI/eva_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..39f3e4cf6cf9985131ae2ef254b59540904b02e7 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/transform.py @@ -0,0 +1,103 @@ +from typing import Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. 
Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +# class CatGen(nn.Module): +# def __init__(self, num=4): +# self.num = num +# def mixgen_batch(image, text): +# batch_size = image.shape[0] +# index = np.random.permutation(batch_size) + +# cat_images = [] +# for i in range(batch_size): +# # image mixup +# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] +# # text concat +# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] +# text = torch.stack(text) +# return image, text + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + normalize = Normalize(mean=mean, std=std) + if is_train: + return Compose([ + RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + else: + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) diff --git a/PuLID_ComfyUI/eva_clip/transformer.py b/PuLID_ComfyUI/eva_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..33e89ff7aa8ff60ae65dcfc5d21cf9af4d214510 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/transformer.py @@ -0,0 +1,737 @@ +import os +import logging +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +try: + from timm.models.layers import trunc_normal_ +except: + from timm.layers import trunc_normal_ + +from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast +from .utils import to_2tuple + +if os.getenv('ENV_TYPE') == 'deepspeed': + try: + import deepspeed + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint + except: + print("Please 'pip install deepspeed'") + deepspeed = None + from torch.utils.checkpoint import checkpoint +else: + from torch.utils.checkpoint import checkpoint + +try: + import xformers.ops as xops +except ImportError: + xops = None + print("Please 'pip install xformers'") + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle 
fp16 (by casting to float32 and back).""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor): + output = F.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(x) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}") + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + if self.training and os.getenv('RoPE') == '1': + return x, patch_indices_keep + + return x + + +def _in_projection_packed( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: Optional[torch.Tensor] = None, + ): + """ + https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726 + """ + E = q.size(-1) + if k is v: + if q is k: + # self-attention + return F.linear(q, w, b).chunk(3, dim=-1) + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. 
/ 0.01), + attn_drop=0., + proj_drop=0., + xattn=False, + rope=False + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + self.rope = rope + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + if self.xattn: + q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + + x = xops.memory_efficient_attention( + q, k, v, + p=self.xattn_drop, + scale=self.scale if self.logit_scale is None else None, + attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None, + ) + else: + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + +class CustomAttention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=True, + scale_heads=False, + logit_scale_max=math.log(1. 
/ 0.01), + attn_drop=0., + proj_drop=0., + xattn=False + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias) + N_q, B_q, C_q = q.shape + N_k, B_k, C_k = k.shape + N_v, B_v, C_v = v.shape + if self.xattn: + # B, N, C -> B, N, num_heads, C + q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1) + k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1) + v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1) + + x = xops.memory_efficient_attention( + q, k, v, + p=self.xattn_drop, + scale=self.scale if self.logit_scale is None else None, + attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None + ) + else: + # B*H, L, C + q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + # B*H, N_q, N_k + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale + attn = attn.view(-1, N_q, N_k) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + + if self.head_scale is not None: + x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale + x = x.view(-1, N_q, C_q) + x = x.transpose(0, 1).reshape(N_q, B_q, C_q) + x = self.out_proj(x) + x = self.out_drop(x) + return x + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + cross_attn: bool = False, + xattn: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1 + self.ln_1_v = 
norm_layer(d_model) if cross_attn else self.ln_1 + self.attn = CustomAttention( + d_model, n_head, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + xattn=xattn + ) + + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask))) + q = q + self.ls_2(self.mlp(self.ln_2(q))) + return q + +class CustomTransformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = True, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + cross_attn: bool = False, + xattn: bool = False, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + self.xattn = xattn + + self.resblocks = nn.ModuleList([ + CustomResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + scale_cosine_attn=scale_cosine_attn, + scale_heads=scale_heads, + scale_attn=scale_attn, + scale_fc=scale_fc, + cross_attn=cross_attn, + xattn=xattn) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None): + if k is None and v is None: + k = v = q + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + q = checkpoint(r, q, k, v, attn_mask) + else: + q = r(q, k, v, attn_mask=attn_mask) + return q + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + if xattn: + self.attn = Attention(d_model, n_head, xattn=True) + else: + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + self.xattn = xattn + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + if self.xattn: + return self.attn(x, attn_mask=attn_mask) + return self.attn(x, x, x, 
need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + patch_dropout: float = 0., + global_average_pool: bool = False, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + self.image_size = to_2tuple(image_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity() + self.ln_pre = norm_layer(width) + + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + xattn=xattn + ) + + self.global_average_pool = global_average_pool + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def get_num_layers(self): + return self.transformer.layers + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {'positional_embedding', 'class_embedding'} + + def forward(self, x: torch.Tensor, return_all_features: bool=False): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if not return_all_features: + if self.global_average_pool: + x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1) + else: + x = x[:, 0] + + x = self.ln_post(x) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class TextTransformer(nn.Module): + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool= False, + attn_mask: bool = True + ): + super().__init__() + self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + xattn=xattn + ) + + self.xattn = xattn + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if attn_mask: + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + else: + self.attn_mask = None + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + # return {'positional_embedding', 'token_embedding'} + return {'positional_embedding'} + + def get_num_layers(self): + return self.transformer.layers + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text, return_all_features: bool=False): + cast_dtype = self.transformer.get_cast_dtype() + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + # x = self.transformer(x) # no attention mask is applied + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if not return_all_features: + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each 
sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return x diff --git a/PuLID_ComfyUI/eva_clip/utils.py b/PuLID_ComfyUI/eva_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc5a7a451fdf8911ebbc816afbd2664ff348836 --- /dev/null +++ b/PuLID_ComfyUI/eva_clip/utils.py @@ -0,0 +1,326 @@ +from itertools import repeat +import collections.abc +import logging +import math +import numpy as np + +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d +import torch.nn.functional as F + +# open CLIP +def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed + + +def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['positional_embedding'] = new_pos_embed + +def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + all_keys = list(state_dict.keys()) + # interpolate position embedding + if 'visual.pos_embed' in 
state_dict: + pos_embed_checkpoint = state_dict['visual.pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['visual.pos_embed'] = new_pos_embed + + patch_embed_proj = state_dict['visual.patch_embed.proj.weight'] + patch_size = model.visual.patch_embed.patch_size + state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate( + patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) + + +def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + all_keys = list(state_dict.keys()) + # interpolate position embedding + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + + patch_embed_proj = state_dict['patch_embed.proj.weight'] + patch_size = model.visual.patch_embed.patch_size + state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate( + patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) + + +def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + all_keys = list(state_dict.keys()) + for key in all_keys: + if "relative_position_index" in key: + state_dict.pop(key) + + if "relative_position_bias_table" in key: + rel_pos_bias = state_dict[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = model.visual.state_dict()[key].size() + dst_patch_shape = 
model.visual.patch_embed.patch_shape + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) + if src_size != dst_size: + print("Position interpolate for %s from %dx%d to %dx%d" % ( + key, src_size, src_size, dst_size, dst_size)) + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + print("Original positions = %s" % str(x)) + print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() + f = F.interpolate.interp2d(x, y, z, kind='cubic') + all_rel_pos_bias.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + state_dict[key] = new_rel_pos_bias + + # interpolate position embedding + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + + patch_embed_proj = state_dict['patch_embed.proj.weight'] + patch_size = model.visual.patch_embed.patch_size + state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate( + patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. 
Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + + +def is_logging(args): + def is_global_master(args): + return args.rank == 0 + + def is_local_master(args): + return args.local_rank == 0 + + def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + return is_master + + +class AllGather(torch.autograd.Function): + """An autograd function that performs allgather on a tensor. + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
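+    The backward pass below restores the gradient for this rank's shard by slicing grad_output
+    at [batch_size * rank : batch_size * (rank + 1)].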
+ """ + + @staticmethod + def forward(ctx, tensor, rank, world_size): + tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(tensors_gather, tensor) + ctx.rank = rank + ctx.batch_size = tensor.shape[0] + return torch.cat(tensors_gather, 0) + + @staticmethod + def backward(ctx, grad_output): + return ( + grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)], + None, + None + ) + +allgather = AllGather.apply \ No newline at end of file diff --git a/PuLID_ComfyUI/examples/PuLID_4-Step_lightning.json b/PuLID_ComfyUI/examples/PuLID_4-Step_lightning.json new file mode 100644 index 0000000000000000000000000000000000000000..00a6a695c5bcf6738d17fd410d264f96b48c4cec --- /dev/null +++ b/PuLID_ComfyUI/examples/PuLID_4-Step_lightning.json @@ -0,0 +1,631 @@ +{ + "last_node_id": 43, + "last_link_id": 128, + "nodes": [ + { + "id": 8, + "type": "VAEDecode", + "pos": [ + 1210, + -270 + ], + "size": { + "0": 140, + "1": 46 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": 8 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 10 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 19, + "type": "PulidEvaClipLoader", + "pos": [ + 130, + 120 + ], + "size": { + "0": 140, + "1": 26 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "EVA_CLIP", + "type": "EVA_CLIP", + "links": [ + 81 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidEvaClipLoader" + } + }, + { + "id": 17, + "type": "PulidInsightFaceLoader", + "pos": [ + 60, + 190 + ], + "size": { + "0": 210, + "1": 58 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "FACEANALYSIS", + "type": "FACEANALYSIS", + "links": [ + 82 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidInsightFaceLoader" + }, + "widgets_values": [ + "CPU" + ] + }, + { + "id": 23, + "type": "CLIPTextEncode", + "pos": [ + 330, + -260 + ], + "size": { + "0": 334.8077697753906, + "1": 189.35675048828125 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 94 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 34 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, low resolution, partially rendered objects, deformed or partially rendered eyes, deformed, deformed eyeballs, cross-eyed,blurry" + ] + }, + { + "id": 22, + "type": "CLIPTextEncode", + "pos": [ + 340, + -430 + ], + "size": { + "0": 315.23089599609375, + "1": 113.96450805664062 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 93 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 35 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "portrait,cinematic,wolf ears,white hair" + ] + }, + { + "id": 10, + "type": "PreviewImage", + "pos": [ + 1230, + -160 + ], + "size": [ + 855.3022058439137, + 
1107.2183523542942 + ], + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 10 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 12, + "type": "LoadImage", + "pos": [ + -117, + 336 + ], + "size": { + "0": 404.07366943359375, + "1": 496.2817077636719 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 114 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "monalisa.png", + "image" + ] + }, + { + "id": 5, + "type": "EmptyLatentImage", + "pos": [ + 353, + 286 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 2 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 768, + 1024, + 1 + ] + }, + { + "id": 16, + "type": "PulidModelLoader", + "pos": [ + -20, + 20 + ], + "size": { + "0": 304.0072021484375, + "1": 58 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "PULID", + "type": "PULID", + "links": [ + 117 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidModelLoader" + }, + "widgets_values": [ + "ip-adapter_pulid_sdxl_fp16.safetensors" + ] + }, + { + "id": 4, + "type": "CheckpointLoaderSimple", + "pos": [ + -130, + -350 + ], + "size": { + "0": 319.03692626953125, + "1": 101.3391342163086 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [], + "slot_index": 0 + }, + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 93, + 94 + ], + "slot_index": 1 + }, + { + "name": "VAE", + "type": "VAE", + "links": [ + 8 + ], + "slot_index": 2 + } + ], + "properties": { + "Node name for S&R": "CheckpointLoaderSimple" + }, + "widgets_values": [ + "sdxl/sd_xl_base_1.0_0.9vae.safetensors" + ] + }, + { + "id": 41, + "type": "UNETLoader", + "pos": [ + -130, + -193 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 6, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 128 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "sdxl_lightning_4step_unet.safetensors" + ] + }, + { + "id": 33, + "type": "ApplyPulid", + "pos": [ + 350, + -10 + ], + "size": { + "0": 315, + "1": 210 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 128 + }, + { + "name": "pulid", + "type": "PULID", + "link": 117 + }, + { + "name": "eva_clip", + "type": "EVA_CLIP", + "link": 81 + }, + { + "name": "face_analysis", + "type": "FACEANALYSIS", + "link": 82 + }, + { + "name": "image", + "type": "IMAGE", + "link": 114 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 120 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyPulid" + }, + "widgets_values": [ + "fidelity", + 0.8, + 0, + 1 + ] + }, + { + "id": 3, + "type": "KSampler", + "pos": [ + 800, + -270 + ], + "size": { + "0": 341.2750244140625, + "1": 262 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 120 + }, + { + "name": "positive", + "type": 
"CONDITIONING", + "link": 35 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 34 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 2 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 7 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "KSampler" + }, + "widgets_values": [ + 42, + "fixed", + 4, + 1.2, + "dpmpp_2m", + "sgm_uniform", + 1 + ] + } + ], + "links": [ + [ + 2, + 5, + 0, + 3, + 3, + "LATENT" + ], + [ + 7, + 3, + 0, + 8, + 0, + "LATENT" + ], + [ + 8, + 4, + 2, + 8, + 1, + "VAE" + ], + [ + 10, + 8, + 0, + 10, + 0, + "IMAGE" + ], + [ + 34, + 23, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 35, + 22, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 81, + 19, + 0, + 33, + 2, + "EVA_CLIP" + ], + [ + 82, + 17, + 0, + 33, + 3, + "FACEANALYSIS" + ], + [ + 93, + 4, + 1, + 22, + 0, + "CLIP" + ], + [ + 94, + 4, + 1, + 23, + 0, + "CLIP" + ], + [ + 114, + 12, + 0, + 33, + 4, + "IMAGE" + ], + [ + 117, + 16, + 0, + 33, + 1, + "PULID" + ], + [ + 120, + 33, + 0, + 3, + 0, + "MODEL" + ], + [ + 128, + 41, + 0, + 33, + 0, + "MODEL" + ] + ], + "groups": [], + "config": {}, + "extra": {}, + "version": 0.4 +} \ No newline at end of file diff --git a/PuLID_ComfyUI/examples/PuLID_IPAdapter_style_transfer.json b/PuLID_ComfyUI/examples/PuLID_IPAdapter_style_transfer.json new file mode 100644 index 0000000000000000000000000000000000000000..5b6f9f556b5dae48a55df8a1021330477e1603d7 --- /dev/null +++ b/PuLID_ComfyUI/examples/PuLID_IPAdapter_style_transfer.json @@ -0,0 +1,794 @@ +{ + "last_node_id": 48, + "last_link_id": 139, + "nodes": [ + { + "id": 19, + "type": "PulidEvaClipLoader", + "pos": [ + 130, + 120 + ], + "size": { + "0": 140, + "1": 26 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "EVA_CLIP", + "type": "EVA_CLIP", + "links": [ + 81 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidEvaClipLoader" + } + }, + { + "id": 17, + "type": "PulidInsightFaceLoader", + "pos": [ + 60, + 190 + ], + "size": { + "0": 210, + "1": 58 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "FACEANALYSIS", + "type": "FACEANALYSIS", + "links": [ + 82 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidInsightFaceLoader" + }, + "widgets_values": [ + "CPU" + ] + }, + { + "id": 16, + "type": "PulidModelLoader", + "pos": [ + -20, + 20 + ], + "size": { + "0": 304.0072021484375, + "1": 58 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "PULID", + "type": "PULID", + "links": [ + 117 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidModelLoader" + }, + "widgets_values": [ + "ip-adapter_pulid_sdxl_fp16.safetensors" + ] + }, + { + "id": 5, + "type": "EmptyLatentImage", + "pos": [ + 350, + 265 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 2 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 768, + 1024, + 1 + ] + }, + { + "id": 23, + "type": "CLIPTextEncode", + "pos": [ + 330, + -260 + ], + "size": { + "0": 334.8077697753906, + "1": 189.35675048828125 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 94 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 34 + ], + 
"shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "blurry, malformed, low quality, worst quality, artifacts, noise, text, watermark, glitch, deformed, ugly, horror, ill" + ] + }, + { + "id": 22, + "type": "CLIPTextEncode", + "pos": [ + 340, + -430 + ], + "size": { + "0": 315.23089599609375, + "1": 113.96450805664062 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 93 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 35 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "closeup portrait, cyberpunk, cinematic, hoodie, purple hair, highly detailed, 4k, high resolution" + ] + }, + { + "id": 12, + "type": "LoadImage", + "pos": [ + -115, + 310 + ], + "size": { + "0": 404.07366943359375, + "1": 496.2817077636719 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 114 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "monalisa.png", + "image" + ] + }, + { + "id": 4, + "type": "CheckpointLoaderSimple", + "pos": [ + -97, + -265 + ], + "size": { + "0": 319.03692626953125, + "1": 101.3391342163086 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 133 + ], + "slot_index": 0 + }, + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 93, + 94 + ], + "slot_index": 1 + }, + { + "name": "VAE", + "type": "VAE", + "links": [ + 8 + ], + "slot_index": 2 + } + ], + "properties": { + "Node name for S&R": "CheckpointLoaderSimple" + }, + "widgets_values": [ + "sdxl/Proteus-RunDiffusion.safetensors" + ] + }, + { + "id": 33, + "type": "ApplyPulid", + "pos": [ + 350, + -10 + ], + "size": { + "0": 315, + "1": 210 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 133 + }, + { + "name": "pulid", + "type": "PULID", + "link": 117 + }, + { + "name": "eva_clip", + "type": "EVA_CLIP", + "link": 81 + }, + { + "name": "face_analysis", + "type": "FACEANALYSIS", + "link": 82 + }, + { + "name": "image", + "type": "IMAGE", + "link": 114 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 136 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyPulid" + }, + "widgets_values": [ + "fidelity", + 0.8, + 0, + 1 + ] + }, + { + "id": 47, + "type": "IPAdapterUnifiedLoader", + "pos": [ + 720, + -10 + ], + "size": [ + 245.09423828124943, + 78 + ], + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 136 + }, + { + "name": "ipadapter", + "type": "IPADAPTER", + "link": null + } + ], + "outputs": [ + { + "name": "model", + "type": "MODEL", + "links": [ + 137 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "ipadapter", + "type": "IPADAPTER", + "links": [ + 135 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "IPAdapterUnifiedLoader" + }, + "widgets_values": [ + "PLUS (high strength)" + ] + }, + { + "id": 8, + "type": "VAEDecode", + "pos": [ + 1831, + 16 + ], + "size": { + "0": 140, + "1": 46 + }, + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": 
"LATENT", + "link": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": 8 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 10 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 10, + "type": "PreviewImage", + "pos": [ + 1817, + 123 + ], + "size": [ + 705.6038401281248, + 950.4616015812499 + ], + "flags": {}, + "order": 14, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 10 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 46, + "type": "IPAdapterAdvanced", + "pos": [ + 1033, + -36 + ], + "size": { + "0": 315, + "1": 278 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 137 + }, + { + "name": "ipadapter", + "type": "IPADAPTER", + "link": 135, + "slot_index": 1 + }, + { + "name": "image", + "type": "IMAGE", + "link": 139, + "slot_index": 2 + }, + { + "name": "image_negative", + "type": "IMAGE", + "link": null + }, + { + "name": "attn_mask", + "type": "MASK", + "link": null + }, + { + "name": "clip_vision", + "type": "CLIP_VISION", + "link": null + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 138 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "IPAdapterAdvanced" + }, + "widgets_values": [ + 1, + "style transfer", + "concat", + 0, + 1, + "V only" + ] + }, + { + "id": 48, + "type": "LoadImage", + "pos": [ + 1032, + 303 + ], + "size": [ + 315, + 314 + ], + "flags": {}, + "order": 6, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 139 + ], + "shape": 3 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "anime_illustration.png", + "image" + ] + }, + { + "id": 3, + "type": "KSampler", + "pos": [ + 1413, + 12 + ], + "size": { + "0": 341.2750244140625, + "1": 262 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 138 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 35 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 34 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 2 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 7 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "KSampler" + }, + "widgets_values": [ + 52, + "fixed", + 30, + 6, + "dpmpp_2m", + "sgm_uniform", + 1 + ] + } + ], + "links": [ + [ + 2, + 5, + 0, + 3, + 3, + "LATENT" + ], + [ + 7, + 3, + 0, + 8, + 0, + "LATENT" + ], + [ + 8, + 4, + 2, + 8, + 1, + "VAE" + ], + [ + 10, + 8, + 0, + 10, + 0, + "IMAGE" + ], + [ + 34, + 23, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 35, + 22, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 81, + 19, + 0, + 33, + 2, + "EVA_CLIP" + ], + [ + 82, + 17, + 0, + 33, + 3, + "FACEANALYSIS" + ], + [ + 93, + 4, + 1, + 22, + 0, + "CLIP" + ], + [ + 94, + 4, + 1, + 23, + 0, + "CLIP" + ], + [ + 114, + 12, + 0, + 33, + 4, + "IMAGE" + ], + [ + 117, + 16, + 0, + 33, + 1, + "PULID" + ], + [ + 133, + 4, + 0, + 33, + 0, + "MODEL" + ], + [ + 135, + 47, + 1, + 46, + 1, + "IPADAPTER" + ], + [ + 136, + 33, + 0, + 47, + 0, + "MODEL" + ], + [ + 137, + 47, + 0, + 46, + 0, + "MODEL" + ], + [ + 138, + 46, + 0, + 3, + 0, + "MODEL" + ], + [ + 139, + 48, + 0, + 46, + 2, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": {}, + 
"version": 0.4 +} \ No newline at end of file diff --git a/PuLID_ComfyUI/examples/PuLID_attention_mask.json b/PuLID_ComfyUI/examples/PuLID_attention_mask.json new file mode 100644 index 0000000000000000000000000000000000000000..305aa567d98da1c2c3d803209d19b5fc0faa5d0f --- /dev/null +++ b/PuLID_ComfyUI/examples/PuLID_attention_mask.json @@ -0,0 +1,946 @@ +{ + "last_node_id": 88, + "last_link_id": 248, + "nodes": [ + { + "id": 5, + "type": "EmptyLatentImage", + "pos": [ + 350, + 265 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 2 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 1280, + 960, + 1 + ] + }, + { + "id": 33, + "type": "ApplyPulid", + "pos": [ + 350, + -10 + ], + "size": { + "0": 315, + "1": 230 + }, + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 133 + }, + { + "name": "pulid", + "type": "PULID", + "link": 117 + }, + { + "name": "eva_clip", + "type": "EVA_CLIP", + "link": 81 + }, + { + "name": "face_analysis", + "type": "FACEANALYSIS", + "link": 82 + }, + { + "name": "image", + "type": "IMAGE", + "link": 114 + }, + { + "name": "attn_mask", + "type": "MASK", + "link": 247 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 141 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyPulid" + }, + "widgets_values": [ + "fidelity", + 0.7000000000000001, + 0, + 1 + ] + }, + { + "id": 85, + "type": "SolidMask", + "pos": [ + -307, + 584 + ], + "size": [ + 210, + 106 + ], + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "MASK", + "type": "MASK", + "links": [ + 244 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "SolidMask" + }, + "widgets_values": [ + 0, + 1280, + 960 + ] + }, + { + "id": 49, + "type": "LoadImage", + "pos": [ + 407, + 550 + ], + "size": [ + 248.03589794921936, + 339.7795556640626 + ], + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 145 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "venere.jpg", + "image" + ] + }, + { + "id": 48, + "type": "InvertMask", + "pos": [ + 526, + 438 + ], + "size": [ + 140, + 26 + ], + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 246 + } + ], + "outputs": [ + { + "name": "MASK", + "type": "MASK", + "links": [ + 151 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "InvertMask" + } + }, + { + "id": 8, + "type": "VAEDecode", + "pos": [ + 1575, + 160 + ], + "size": { + "0": 140, + "1": 46 + }, + "flags": {}, + "order": 16, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": 8 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 10 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 10, + "type": "PreviewImage", + "pos": [ + 1592, + 279 + ], + "size": [ + 1370.7157657734379, + 1041.8039240156252 + ], + "flags": {}, + "order": 17, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + 
"link": 10 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 16, + "type": "PulidModelLoader", + "pos": [ + -111, + -181 + ], + "size": { + "0": 304.0072021484375, + "1": 58 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "PULID", + "type": "PULID", + "links": [ + 117, + 136 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidModelLoader" + }, + "widgets_values": [ + "ip-adapter_pulid_sdxl_fp16.safetensors" + ] + }, + { + "id": 19, + "type": "PulidEvaClipLoader", + "pos": [ + 54, + -69 + ], + "size": { + "0": 140, + "1": 26 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "EVA_CLIP", + "type": "EVA_CLIP", + "links": [ + 81, + 137 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidEvaClipLoader" + } + }, + { + "id": 17, + "type": "PulidInsightFaceLoader", + "pos": [ + -18, + 12 + ], + "size": { + "0": 210, + "1": 58 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "FACEANALYSIS", + "type": "FACEANALYSIS", + "links": [ + 82, + 138 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidInsightFaceLoader" + }, + "widgets_values": [ + "CPU" + ] + }, + { + "id": 12, + "type": "LoadImage", + "pos": [ + -34, + 145 + ], + "size": [ + 261.645185990767, + 346.38255171342325 + ], + "flags": {}, + "order": 6, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 114 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "monalisa.png", + "image" + ] + }, + { + "id": 87, + "type": "MaskComposite", + "pos": [ + 15, + 546 + ], + "size": [ + 210, + 126 + ], + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "destination", + "type": "MASK", + "link": 244 + }, + { + "name": "source", + "type": "MASK", + "link": 245 + } + ], + "outputs": [ + { + "name": "MASK", + "type": "MASK", + "links": [ + 246, + 247 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "MaskComposite" + }, + "widgets_values": [ + 0, + 0, + "add" + ] + }, + { + "id": 86, + "type": "SolidMask", + "pos": [ + -304, + 747 + ], + "size": { + "0": 210, + "1": 106 + }, + "flags": {}, + "order": 7, + "mode": 0, + "outputs": [ + { + "name": "MASK", + "type": "MASK", + "links": [ + 245 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "SolidMask" + }, + "widgets_values": [ + 1, + 640, + 960 + ] + }, + { + "id": 23, + "type": "CLIPTextEncode", + "pos": [ + 756, + -47 + ], + "size": [ + 316.32471195096673, + 101.97065006593618 + ], + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 94 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 34 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "blurry, malformed, low quality, worst quality, artifacts, noise, text, watermark, glitch, deformed, ugly, horror, ill" + ] + }, + { + "id": 47, + "type": "ApplyPulid", + "pos": [ + 765, + 128 + ], + "size": { + "0": 315, + "1": 230 + }, + "flags": {}, + "order": 14, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 141 + }, + { + "name": "pulid", + "type": 
"PULID", + "link": 136 + }, + { + "name": "eva_clip", + "type": "EVA_CLIP", + "link": 137 + }, + { + "name": "face_analysis", + "type": "FACEANALYSIS", + "link": 138 + }, + { + "name": "image", + "type": "IMAGE", + "link": 145 + }, + { + "name": "attn_mask", + "type": "MASK", + "link": 151 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 142 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyPulid" + }, + "widgets_values": [ + "fidelity", + 0.7000000000000001, + 0, + 1 + ] + }, + { + "id": 55, + "type": "CLIPTextEncode", + "pos": [ + 755, + -211 + ], + "size": { + "0": 315.23089599609375, + "1": 113.96450805664062 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 156 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 160 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "closeup two girl friends on the streets of a cyberpunk city, cinematic, hoodie, multicolored hair, highly detailed, 4k, high resolution" + ] + }, + { + "id": 3, + "type": "KSampler", + "pos": [ + 1162, + 38 + ], + "size": { + "0": 341.2750244140625, + "1": 262 + }, + "flags": {}, + "order": 15, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 142 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 160 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 34 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 2 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 7 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "KSampler" + }, + "widgets_values": [ + 70, + "fixed", + 30, + 6, + "dpmpp_2m", + "karras", + 1 + ] + }, + { + "id": 4, + "type": "CheckpointLoaderSimple", + "pos": [ + -131, + -342 + ], + "size": { + "0": 319.03692626953125, + "1": 101.3391342163086 + }, + "flags": {}, + "order": 8, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 133 + ], + "slot_index": 0 + }, + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 94, + 156 + ], + "slot_index": 1 + }, + { + "name": "VAE", + "type": "VAE", + "links": [ + 8 + ], + "slot_index": 2 + } + ], + "properties": { + "Node name for S&R": "CheckpointLoaderSimple" + }, + "widgets_values": [ + "sdxl/AlbedoBaseXL.safetensors" + ] + } + ], + "links": [ + [ + 2, + 5, + 0, + 3, + 3, + "LATENT" + ], + [ + 7, + 3, + 0, + 8, + 0, + "LATENT" + ], + [ + 8, + 4, + 2, + 8, + 1, + "VAE" + ], + [ + 10, + 8, + 0, + 10, + 0, + "IMAGE" + ], + [ + 34, + 23, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 81, + 19, + 0, + 33, + 2, + "EVA_CLIP" + ], + [ + 82, + 17, + 0, + 33, + 3, + "FACEANALYSIS" + ], + [ + 94, + 4, + 1, + 23, + 0, + "CLIP" + ], + [ + 114, + 12, + 0, + 33, + 4, + "IMAGE" + ], + [ + 117, + 16, + 0, + 33, + 1, + "PULID" + ], + [ + 133, + 4, + 0, + 33, + 0, + "MODEL" + ], + [ + 136, + 16, + 0, + 47, + 1, + "PULID" + ], + [ + 137, + 19, + 0, + 47, + 2, + "EVA_CLIP" + ], + [ + 138, + 17, + 0, + 47, + 3, + "FACEANALYSIS" + ], + [ + 141, + 33, + 0, + 47, + 0, + "MODEL" + ], + [ + 142, + 47, + 0, + 3, + 0, + "MODEL" + ], + [ + 145, + 49, + 0, + 47, + 4, + "IMAGE" + ], + [ + 151, + 48, + 0, + 47, + 5, + "MASK" + ], + [ + 156, + 4, + 1, + 55, + 0, + "CLIP" + ], + [ + 160, + 55, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 244, + 85, + 0, + 87, + 0, + "MASK" + ], + [ + 245, + 
86, + 0, + 87, + 1, + "MASK" + ], + [ + 246, + 87, + 0, + 48, + 0, + "MASK" + ], + [ + 247, + 87, + 0, + 33, + 5, + "MASK" + ] + ], + "groups": [], + "config": {}, + "extra": {}, + "version": 0.4 +} \ No newline at end of file diff --git a/PuLID_ComfyUI/examples/PuLID_lightning_lora.json b/PuLID_ComfyUI/examples/PuLID_lightning_lora.json new file mode 100644 index 0000000000000000000000000000000000000000..a817888dc31e63259f9c196c3b10523c054dc774 --- /dev/null +++ b/PuLID_ComfyUI/examples/PuLID_lightning_lora.json @@ -0,0 +1,649 @@ +{ + "last_node_id": 45, + "last_link_id": 132, + "nodes": [ + { + "id": 8, + "type": "VAEDecode", + "pos": [ + 1210, + -270 + ], + "size": { + "0": 140, + "1": 46 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": 8 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 10 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 19, + "type": "PulidEvaClipLoader", + "pos": [ + 130, + 120 + ], + "size": { + "0": 140, + "1": 26 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "EVA_CLIP", + "type": "EVA_CLIP", + "links": [ + 81 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidEvaClipLoader" + } + }, + { + "id": 17, + "type": "PulidInsightFaceLoader", + "pos": [ + 60, + 190 + ], + "size": { + "0": 210, + "1": 58 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "FACEANALYSIS", + "type": "FACEANALYSIS", + "links": [ + 82 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidInsightFaceLoader" + }, + "widgets_values": [ + "CPU" + ] + }, + { + "id": 23, + "type": "CLIPTextEncode", + "pos": [ + 330, + -260 + ], + "size": { + "0": 334.8077697753906, + "1": 189.35675048828125 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 94 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 34 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, low resolution, partially rendered objects, deformed or partially rendered eyes, deformed, deformed eyeballs, cross-eyed,blurry" + ] + }, + { + "id": 22, + "type": "CLIPTextEncode", + "pos": [ + 340, + -430 + ], + "size": { + "0": 315.23089599609375, + "1": 113.96450805664062 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 93 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 35 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "portrait,cinematic,wolf ears,white hair" + ] + }, + { + "id": 10, + "type": "PreviewImage", + "pos": [ + 1230, + -160 + ], + "size": [ + 855.3022058439137, + 1107.2183523542942 + ], + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 10 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 12, + "type": "LoadImage", + "pos": [ + -117, + 336 + ], + "size": { + "0": 
404.07366943359375, + "1": 496.2817077636719 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 114 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "monalisa.png", + "image" + ] + }, + { + "id": 16, + "type": "PulidModelLoader", + "pos": [ + -20, + 20 + ], + "size": { + "0": 304.0072021484375, + "1": 58 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "PULID", + "type": "PULID", + "links": [ + 117 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidModelLoader" + }, + "widgets_values": [ + "ip-adapter_pulid_sdxl_fp16.safetensors" + ] + }, + { + "id": 45, + "type": "LoraLoaderModelOnly", + "pos": [ + 4, + -328 + ], + "size": [ + 267.767924663449, + 82 + ], + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 129 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 131 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "LoraLoaderModelOnly" + }, + "widgets_values": [ + "sdxl_lightning_4step_lora.safetensors", + 1 + ] + }, + { + "id": 5, + "type": "EmptyLatentImage", + "pos": [ + 350, + 265 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 2 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 768, + 1024, + 1 + ] + }, + { + "id": 33, + "type": "ApplyPulid", + "pos": [ + 350, + -10 + ], + "size": { + "0": 315, + "1": 210 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 131 + }, + { + "name": "pulid", + "type": "PULID", + "link": 117 + }, + { + "name": "eva_clip", + "type": "EVA_CLIP", + "link": 81 + }, + { + "name": "face_analysis", + "type": "FACEANALYSIS", + "link": 82 + }, + { + "name": "image", + "type": "IMAGE", + "link": 114 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 132 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyPulid" + }, + "widgets_values": [ + "fidelity", + 0.8, + 0, + 1 + ] + }, + { + "id": 4, + "type": "CheckpointLoaderSimple", + "pos": [ + -378, + -329 + ], + "size": { + "0": 319.03692626953125, + "1": 101.3391342163086 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 129 + ], + "slot_index": 0 + }, + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 93, + 94 + ], + "slot_index": 1 + }, + { + "name": "VAE", + "type": "VAE", + "links": [ + 8 + ], + "slot_index": 2 + } + ], + "properties": { + "Node name for S&R": "CheckpointLoaderSimple" + }, + "widgets_values": [ + "sdxl/juggernautXL_version8Rundiffusion.safetensors" + ] + }, + { + "id": 3, + "type": "KSampler", + "pos": [ + 800, + -270 + ], + "size": { + "0": 341.2750244140625, + "1": 262 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 132 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 35 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 34 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 2 + } + ], + "outputs": 
[ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 7 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "KSampler" + }, + "widgets_values": [ + 42, + "fixed", + 4, + 1.2, + "dpmpp_2m", + "sgm_uniform", + 1 + ] + } + ], + "links": [ + [ + 2, + 5, + 0, + 3, + 3, + "LATENT" + ], + [ + 7, + 3, + 0, + 8, + 0, + "LATENT" + ], + [ + 8, + 4, + 2, + 8, + 1, + "VAE" + ], + [ + 10, + 8, + 0, + 10, + 0, + "IMAGE" + ], + [ + 34, + 23, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 35, + 22, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 81, + 19, + 0, + 33, + 2, + "EVA_CLIP" + ], + [ + 82, + 17, + 0, + 33, + 3, + "FACEANALYSIS" + ], + [ + 93, + 4, + 1, + 22, + 0, + "CLIP" + ], + [ + 94, + 4, + 1, + 23, + 0, + "CLIP" + ], + [ + 114, + 12, + 0, + 33, + 4, + "IMAGE" + ], + [ + 117, + 16, + 0, + 33, + 1, + "PULID" + ], + [ + 129, + 4, + 0, + 45, + 0, + "MODEL" + ], + [ + 131, + 45, + 0, + 33, + 0, + "MODEL" + ], + [ + 132, + 33, + 0, + 3, + 0, + "MODEL" + ] + ], + "groups": [], + "config": {}, + "extra": {}, + "version": 0.4 +} \ No newline at end of file diff --git a/PuLID_ComfyUI/examples/PuLID_simple.json b/PuLID_ComfyUI/examples/PuLID_simple.json new file mode 100644 index 0000000000000000000000000000000000000000..ae770ed4a575a6580e207d8c159812c32c026b0b --- /dev/null +++ b/PuLID_ComfyUI/examples/PuLID_simple.json @@ -0,0 +1,601 @@ +{ + "last_node_id": 45, + "last_link_id": 133, + "nodes": [ + { + "id": 8, + "type": "VAEDecode", + "pos": [ + 1210, + -270 + ], + "size": { + "0": 140, + "1": 46 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": 8 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 10 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 19, + "type": "PulidEvaClipLoader", + "pos": [ + 130, + 120 + ], + "size": { + "0": 140, + "1": 26 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "EVA_CLIP", + "type": "EVA_CLIP", + "links": [ + 81 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidEvaClipLoader" + } + }, + { + "id": 17, + "type": "PulidInsightFaceLoader", + "pos": [ + 60, + 190 + ], + "size": { + "0": 210, + "1": 58 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "FACEANALYSIS", + "type": "FACEANALYSIS", + "links": [ + 82 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidInsightFaceLoader" + }, + "widgets_values": [ + "CPU" + ] + }, + { + "id": 16, + "type": "PulidModelLoader", + "pos": [ + -20, + 20 + ], + "size": { + "0": 304.0072021484375, + "1": 58 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "PULID", + "type": "PULID", + "links": [ + 117 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PulidModelLoader" + }, + "widgets_values": [ + "ip-adapter_pulid_sdxl_fp16.safetensors" + ] + }, + { + "id": 5, + "type": "EmptyLatentImage", + "pos": [ + 350, + 265 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 2 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 768, + 1024, + 1 + ] + }, + { + "id": 33, + "type": "ApplyPulid", + "pos": [ + 350, + -10 + ], + "size": { + "0": 315, + "1": 
210 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 133 + }, + { + "name": "pulid", + "type": "PULID", + "link": 117 + }, + { + "name": "eva_clip", + "type": "EVA_CLIP", + "link": 81 + }, + { + "name": "face_analysis", + "type": "FACEANALYSIS", + "link": 82 + }, + { + "name": "image", + "type": "IMAGE", + "link": 114 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 132 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyPulid" + }, + "widgets_values": [ + "fidelity", + 0.8, + 0, + 1 + ] + }, + { + "id": 23, + "type": "CLIPTextEncode", + "pos": [ + 330, + -260 + ], + "size": { + "0": 334.8077697753906, + "1": 189.35675048828125 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 94 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 34 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "blurry, malformed, low quality, worst quality, artifacts, noise, text, watermark, glitch, deformed, ugly, horror, ill" + ] + }, + { + "id": 22, + "type": "CLIPTextEncode", + "pos": [ + 340, + -430 + ], + "size": { + "0": 315.23089599609375, + "1": 113.96450805664062 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 93 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 35 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncode" + }, + "widgets_values": [ + "closeup portrait, cyberpunk, cinematic, hoodie, purple hair, highly detailed, 4k, high resolution" + ] + }, + { + "id": 3, + "type": "KSampler", + "pos": [ + 800, + -270 + ], + "size": { + "0": 341.2750244140625, + "1": 262 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 132 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 35 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 34 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 2 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 7 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "KSampler" + }, + "widgets_values": [ + 51, + "fixed", + 30, + 6, + "dpmpp_2m", + "sgm_uniform", + 1 + ] + }, + { + "id": 12, + "type": "LoadImage", + "pos": [ + -115, + 310 + ], + "size": { + "0": 404.07366943359375, + "1": 496.2817077636719 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 114 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "monalisa.png", + "image" + ] + }, + { + "id": 4, + "type": "CheckpointLoaderSimple", + "pos": [ + -97, + -265 + ], + "size": { + "0": 319.03692626953125, + "1": 101.3391342163086 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 133 + ], + "slot_index": 0 + }, + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 93, + 94 + ], + "slot_index": 1 + }, + { + "name": "VAE", + "type": "VAE", + "links": [ + 8 + ], + "slot_index": 2 + } + ], + "properties": { + "Node name for 
S&R": "CheckpointLoaderSimple" + }, + "widgets_values": [ + "sdxl/Proteus-RunDiffusion.safetensors" + ] + }, + { + "id": 10, + "type": "PreviewImage", + "pos": [ + 1181, + -162 + ], + "size": [ + 705.6038401281248, + 950.4616015812499 + ], + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 10 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + } + ], + "links": [ + [ + 2, + 5, + 0, + 3, + 3, + "LATENT" + ], + [ + 7, + 3, + 0, + 8, + 0, + "LATENT" + ], + [ + 8, + 4, + 2, + 8, + 1, + "VAE" + ], + [ + 10, + 8, + 0, + 10, + 0, + "IMAGE" + ], + [ + 34, + 23, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 35, + 22, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 81, + 19, + 0, + 33, + 2, + "EVA_CLIP" + ], + [ + 82, + 17, + 0, + 33, + 3, + "FACEANALYSIS" + ], + [ + 93, + 4, + 1, + 22, + 0, + "CLIP" + ], + [ + 94, + 4, + 1, + 23, + 0, + "CLIP" + ], + [ + 114, + 12, + 0, + 33, + 4, + "IMAGE" + ], + [ + 117, + 16, + 0, + 33, + 1, + "PULID" + ], + [ + 132, + 33, + 0, + 3, + 0, + "MODEL" + ], + [ + 133, + 4, + 0, + 33, + 0, + "MODEL" + ] + ], + "groups": [], + "config": {}, + "extra": {}, + "version": 0.4 +} \ No newline at end of file diff --git a/PuLID_ComfyUI/examples/pulid_wf.jpg b/PuLID_ComfyUI/examples/pulid_wf.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0fade0000aee5849aaac15bae505b3029ea7bc99 --- /dev/null +++ b/PuLID_ComfyUI/examples/pulid_wf.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb945f14a747a03cfbacf20d3ef3be2f3d9b1b60757ce542f35e21ac8d922180 +size 133028 diff --git a/PuLID_ComfyUI/pulid.py b/PuLID_ComfyUI/pulid.py new file mode 100644 index 0000000000000000000000000000000000000000..e603f74907065c6621ec5dc8d4c2f00e9ef78776 --- /dev/null +++ b/PuLID_ComfyUI/pulid.py @@ -0,0 +1,492 @@ +import torch +from torch import nn +import torchvision.transforms as T +import torch.nn.functional as F +import os +import math +import folder_paths +import comfy.utils +from insightface.app import FaceAnalysis +from facexlib.parsing import init_parsing_model +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from comfy.ldm.modules.attention import optimized_attention + +from .eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + +from .encoders import IDEncoder + +INSIGHTFACE_DIR = os.path.join(folder_paths.models_dir, "insightface") + +MODELS_DIR = os.path.join(folder_paths.models_dir, "pulid") +if "pulid" not in folder_paths.folder_names_and_paths: + current_paths = [MODELS_DIR] +else: + current_paths, _ = folder_paths.folder_names_and_paths["pulid"] +folder_paths.folder_names_and_paths["pulid"] = (current_paths, folder_paths.supported_pt_extensions) + +class PulidModel(nn.Module): + def __init__(self, model): + super().__init__() + + self.model = model + self.image_proj_model = self.init_id_adapter() + self.image_proj_model.load_state_dict(model["image_proj"]) + self.ip_layers = To_KV(model["ip_adapter"]) + + def init_id_adapter(self): + image_proj_model = IDEncoder() + return image_proj_model + + def get_image_embeds(self, face_embed, clip_embeds): + embeds = self.image_proj_model(face_embed, clip_embeds) + return embeds + +class To_KV(nn.Module): + def __init__(self, state_dict): + super().__init__() + + self.to_kvs = nn.ModuleDict() + for key, value in state_dict.items(): + self.to_kvs[key.replace(".weight", "").replace(".", "_")] = nn.Linear(value.shape[1], value.shape[0], bias=False) + self.to_kvs[key.replace(".weight", "").replace(".", 
"_")].weight.data = value + +def tensor_to_image(tensor): + image = tensor.mul(255).clamp(0, 255).byte().cpu() + image = image[..., [2, 1, 0]].numpy() + return image + +def image_to_tensor(image): + tensor = torch.clamp(torch.from_numpy(image).float() / 255., 0, 1) + tensor = tensor[..., [2, 1, 0]] + return tensor + +def tensor_to_size(source, dest_size): + if isinstance(dest_size, torch.Tensor): + dest_size = dest_size.shape[0] + source_size = source.shape[0] + + if source_size < dest_size: + shape = [dest_size - source_size] + [1]*(source.dim()-1) + source = torch.cat((source, source[-1:].repeat(shape)), dim=0) + elif source_size > dest_size: + source = source[:dest_size] + + return source + +def set_model_patch_replace(model, patch_kwargs, key): + to = model.model_options["transformer_options"].copy() + if "patches_replace" not in to: + to["patches_replace"] = {} + else: + to["patches_replace"] = to["patches_replace"].copy() + + if "attn2" not in to["patches_replace"]: + to["patches_replace"]["attn2"] = {} + else: + to["patches_replace"]["attn2"] = to["patches_replace"]["attn2"].copy() + + if key not in to["patches_replace"]["attn2"]: + to["patches_replace"]["attn2"][key] = Attn2Replace(pulid_attention, **patch_kwargs) + model.model_options["transformer_options"] = to + else: + to["patches_replace"]["attn2"][key].add(pulid_attention, **patch_kwargs) + +class Attn2Replace: + def __init__(self, callback=None, **kwargs): + self.callback = [callback] + self.kwargs = [kwargs] + + def add(self, callback, **kwargs): + self.callback.append(callback) + self.kwargs.append(kwargs) + + for key, value in kwargs.items(): + setattr(self, key, value) + + def __call__(self, q, k, v, extra_options): + dtype = q.dtype + out = optimized_attention(q, k, v, extra_options["n_heads"]) + sigma = extra_options["sigmas"].detach().cpu()[0].item() if 'sigmas' in extra_options else 999999999.9 + + for i, callback in enumerate(self.callback): + if sigma <= self.kwargs[i]["sigma_start"] and sigma >= self.kwargs[i]["sigma_end"]: + out = out + callback(out, q, k, v, extra_options, **self.kwargs[i]) + + return out.to(dtype=dtype) + +def pulid_attention(out, q, k, v, extra_options, module_key='', pulid=None, cond=None, uncond=None, weight=1.0, ortho=False, ortho_v2=False, mask=None, **kwargs): + k_key = module_key + "_to_k_ip" + v_key = module_key + "_to_v_ip" + + dtype = q.dtype + seq_len = q.shape[1] + cond_or_uncond = extra_options["cond_or_uncond"] + b = q.shape[0] + batch_prompt = b // len(cond_or_uncond) + _, _, oh, ow = extra_options["original_shape"] + + #conds = torch.cat([uncond.repeat(batch_prompt, 1, 1), cond.repeat(batch_prompt, 1, 1)], dim=0) + #zero_tensor = torch.zeros((conds.size(0), num_zero, conds.size(-1)), dtype=conds.dtype, device=conds.device) + #conds = torch.cat([conds, zero_tensor], dim=1) + #ip_k = pulid.ip_layers.to_kvs[k_key](conds) + #ip_v = pulid.ip_layers.to_kvs[v_key](conds) + + k_cond = pulid.ip_layers.to_kvs[k_key](cond).repeat(batch_prompt, 1, 1) + k_uncond = pulid.ip_layers.to_kvs[k_key](uncond).repeat(batch_prompt, 1, 1) + v_cond = pulid.ip_layers.to_kvs[v_key](cond).repeat(batch_prompt, 1, 1) + v_uncond = pulid.ip_layers.to_kvs[v_key](uncond).repeat(batch_prompt, 1, 1) + ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0) + ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0) + + out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"]) + + if ortho: + out = out.to(dtype=torch.float32) + out_ip = out_ip.to(dtype=torch.float32) + 
projection = (torch.sum((out * out_ip), dim=-2, keepdim=True) / torch.sum((out * out), dim=-2, keepdim=True) * out) + orthogonal = out_ip - projection + out_ip = weight * orthogonal + elif ortho_v2: + out = out.to(dtype=torch.float32) + out_ip = out_ip.to(dtype=torch.float32) + attn_map = q @ ip_k.transpose(-2, -1) + attn_mean = attn_map.softmax(dim=-1).mean(dim=1, keepdim=True) + attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True) + projection = (torch.sum((out * out_ip), dim=-2, keepdim=True) / torch.sum((out * out), dim=-2, keepdim=True) * out) + orthogonal = out_ip + (attn_mean - 1) * projection + out_ip = weight * orthogonal + else: + out_ip = out_ip * weight + + if mask is not None: + mask_h = oh / math.sqrt(oh * ow / seq_len) + mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0) + mask_w = seq_len // mask_h + + mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1) + mask = tensor_to_size(mask, batch_prompt) + + mask = mask.repeat(len(cond_or_uncond), 1, 1) + mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2]) + + # covers cases where extreme aspect ratios can cause the mask to have a wrong size + mask_len = mask_h * mask_w + if mask_len < seq_len: + pad_len = seq_len - mask_len + pad1 = pad_len // 2 + pad2 = pad_len - pad1 + mask = F.pad(mask, (0, 0, pad1, pad2), value=0.0) + elif mask_len > seq_len: + crop_start = (mask_len - seq_len) // 2 + mask = mask[:, crop_start:crop_start+seq_len, :] + + out_ip = out_ip * mask + + return out_ip.to(dtype=dtype) + +def to_gray(img): + x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] + x = x.repeat(1, 3, 1, 1) + return x + +""" +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + Nodes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +""" + +class PulidModelLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "pulid_file": (folder_paths.get_filename_list("pulid"), )}} + + RETURN_TYPES = ("PULID",) + FUNCTION = "load_model" + CATEGORY = "pulid" + + def load_model(self, pulid_file): + ckpt_path = folder_paths.get_full_path("pulid", pulid_file) + + model = comfy.utils.load_torch_file(ckpt_path, safe_load=True) + + if ckpt_path.lower().endswith(".safetensors"): + st_model = {"image_proj": {}, "ip_adapter": {}} + for key in model.keys(): + if key.startswith("image_proj."): + st_model["image_proj"][key.replace("image_proj.", "")] = model[key] + elif key.startswith("ip_adapter."): + st_model["ip_adapter"][key.replace("ip_adapter.", "")] = model[key] + model = st_model + + # Also initialize the model, takes longer to load but then it doesn't have to be done every time you change parameters in the apply node + model = PulidModel(model) + + return (model,) + +class PulidInsightFaceLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "provider": (["CPU", "CUDA", "ROCM"], ), + }, + } + + RETURN_TYPES = ("FACEANALYSIS",) + FUNCTION = "load_insightface" + CATEGORY = "pulid" + + def load_insightface(self, provider): + model = FaceAnalysis(name="antelopev2", root=INSIGHTFACE_DIR, providers=[provider + 'ExecutionProvider',]) # alternative to buffalo_l + model.prepare(ctx_id=0, det_size=(640, 640)) + + return (model,) + +class PulidEvaClipLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": {}, + } + + RETURN_TYPES = ("EVA_CLIP",) + FUNCTION = "load_eva_clip" + CATEGORY = "pulid" + + def load_eva_clip(self): + from .eva_clip.factory import create_model_and_transforms + + model, _, _ 
= create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True) + + model = model.visual + + eva_transform_mean = getattr(model, 'image_mean', OPENAI_DATASET_MEAN) + eva_transform_std = getattr(model, 'image_std', OPENAI_DATASET_STD) + if not isinstance(eva_transform_mean, (list, tuple)): + model["image_mean"] = (eva_transform_mean,) * 3 + if not isinstance(eva_transform_std, (list, tuple)): + model["image_std"] = (eva_transform_std,) * 3 + + return (model,) + + +class ApplyPulid: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", ), + "pulid": ("PULID", ), + "eva_clip": ("EVA_CLIP", ), + "face_analysis": ("FACEANALYSIS", ), + "image": ("IMAGE", ), + "method": (["fidelity", "style", "neutral"],), + "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }), + "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }), + "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }), + }, + "optional": { + "attn_mask": ("MASK", ), + }, + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply_pulid" + CATEGORY = "pulid" + + def apply_pulid(self, model, pulid, eva_clip, face_analysis, image, weight, start_at, end_at, method=None, noise=0.0, fidelity=None, projection=None, attn_mask=None): + work_model = model.clone() + + device = comfy.model_management.get_torch_device() + dtype = comfy.model_management.unet_dtype() + if dtype not in [torch.float32, torch.float16, torch.bfloat16]: + dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32 + + eva_clip.to(device, dtype=dtype) + pulid_model = pulid.to(device, dtype=dtype) + + if attn_mask is not None: + if attn_mask.dim() > 3: + attn_mask = attn_mask.squeeze(-1) + elif attn_mask.dim() < 3: + attn_mask = attn_mask.unsqueeze(0) + attn_mask = attn_mask.to(device, dtype=dtype) + + if method == "fidelity" or projection == "ortho_v2": + num_zero = 8 + ortho = False + ortho_v2 = True + elif method == "style" or projection == "ortho": + num_zero = 16 + ortho = True + ortho_v2 = False + else: + num_zero = 0 + ortho = False + ortho_v2 = False + + if fidelity is not None: + num_zero = fidelity + + #face_analysis.det_model.input_size = (640,640) + image = tensor_to_image(image) + + face_helper = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + device=device, + ) + + face_helper.face_parse = None + face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device) + + bg_label = [0, 16, 18, 7, 8, 9, 14, 15] + cond = [] + uncond = [] + + for i in range(image.shape[0]): + # get insightface embeddings + iface_embeds = None + for size in [(size, size) for size in range(640, 256, -64)]: + face_analysis.det_model.input_size = size + face = face_analysis.get(image[i]) + if face: + face = sorted(face, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse=True)[-1] + iface_embeds = torch.from_numpy(face.embedding).unsqueeze(0).to(device, dtype=dtype) + break + else: + raise Exception('insightface: No face detected.') + + # get eva_clip embeddings + face_helper.clean_all() + face_helper.read_image(image[i]) + face_helper.get_face_landmarks_5(only_center_face=True) + face_helper.align_warp_face() + + if len(face_helper.cropped_faces) == 0: + raise Exception('facexlib: No face detected.') + + face = face_helper.cropped_faces[0] + face = image_to_tensor(face).unsqueeze(0).permute(0,3,1,2).to(device) + parsing_out = 
face_helper.face_parse(T.functional.normalize(face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0] + parsing_out = parsing_out.argmax(dim=1, keepdim=True) + bg = sum(parsing_out == i for i in bg_label).bool() + white_image = torch.ones_like(face) + face_features_image = torch.where(bg, white_image, to_gray(face)) + # apparently MPS only supports NEAREST interpolation? + face_features_image = T.functional.resize(face_features_image, eva_clip.image_size, T.InterpolationMode.BICUBIC if 'cuda' in device.type else T.InterpolationMode.NEAREST).to(device, dtype=dtype) + face_features_image = T.functional.normalize(face_features_image, eva_clip.image_mean, eva_clip.image_std) + + id_cond_vit, id_vit_hidden = eva_clip(face_features_image, return_all_features=False, return_hidden=True, shuffle=False) + id_cond_vit = id_cond_vit.to(device, dtype=dtype) + for idx in range(len(id_vit_hidden)): + id_vit_hidden[idx] = id_vit_hidden[idx].to(device, dtype=dtype) + + id_cond_vit = torch.div(id_cond_vit, torch.norm(id_cond_vit, 2, 1, True)) + + # combine embeddings + id_cond = torch.cat([iface_embeds, id_cond_vit], dim=-1) + if noise == 0: + id_uncond = torch.zeros_like(id_cond) + else: + id_uncond = torch.rand_like(id_cond) * noise + id_vit_hidden_uncond = [] + for idx in range(len(id_vit_hidden)): + if noise == 0: + id_vit_hidden_uncond.append(torch.zeros_like(id_vit_hidden[idx])) + else: + id_vit_hidden_uncond.append(torch.rand_like(id_vit_hidden[idx]) * noise) + + cond.append(pulid_model.get_image_embeds(id_cond, id_vit_hidden)) + uncond.append(pulid_model.get_image_embeds(id_uncond, id_vit_hidden_uncond)) + + # average embeddings + cond = torch.cat(cond).to(device, dtype=dtype) + uncond = torch.cat(uncond).to(device, dtype=dtype) + if cond.shape[0] > 1: + cond = torch.mean(cond, dim=0, keepdim=True) + uncond = torch.mean(uncond, dim=0, keepdim=True) + + if num_zero > 0: + if noise == 0: + zero_tensor = torch.zeros((cond.size(0), num_zero, cond.size(-1)), dtype=dtype, device=device) + else: + zero_tensor = torch.rand((cond.size(0), num_zero, cond.size(-1)), dtype=dtype, device=device) * noise + cond = torch.cat([cond, zero_tensor], dim=1) + uncond = torch.cat([uncond, zero_tensor], dim=1) + + sigma_start = work_model.get_model_object("model_sampling").percent_to_sigma(start_at) + sigma_end = work_model.get_model_object("model_sampling").percent_to_sigma(end_at) + + patch_kwargs = { + "pulid": pulid_model, + "weight": weight, + "cond": cond, + "uncond": uncond, + "sigma_start": sigma_start, + "sigma_end": sigma_end, + "ortho": ortho, + "ortho_v2": ortho_v2, + "mask": attn_mask, + } + + number = 0 + for id in [4,5,7,8]: # id of input_blocks that have cross attention + block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth + for index in block_indices: + patch_kwargs["module_key"] = str(number*2+1) + set_model_patch_replace(work_model, patch_kwargs, ("input", id, index)) + number += 1 + for id in range(6): # id of output_blocks that have cross attention + block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth + for index in block_indices: + patch_kwargs["module_key"] = str(number*2+1) + set_model_patch_replace(work_model, patch_kwargs, ("output", id, index)) + number += 1 + for index in range(10): + patch_kwargs["module_key"] = str(number*2+1) + set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index)) + number += 1 + + return (work_model,) + +class ApplyPulidAdvanced(ApplyPulid): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + 
"model": ("MODEL", ), + "pulid": ("PULID", ), + "eva_clip": ("EVA_CLIP", ), + "face_analysis": ("FACEANALYSIS", ), + "image": ("IMAGE", ), + "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }), + "projection": (["ortho_v2", "ortho", "none"],), + "fidelity": ("INT", {"default": 8, "min": 0, "max": 32, "step": 1 }), + "noise": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.1 }), + "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }), + "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }), + }, + "optional": { + "attn_mask": ("MASK", ), + }, + } + +NODE_CLASS_MAPPINGS = { + "PulidModelLoader": PulidModelLoader, + "PulidInsightFaceLoader": PulidInsightFaceLoader, + "PulidEvaClipLoader": PulidEvaClipLoader, + "ApplyPulid": ApplyPulid, + "ApplyPulidAdvanced": ApplyPulidAdvanced, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PulidModelLoader": "Load PuLID Model", + "PulidInsightFaceLoader": "Load InsightFace (PuLID)", + "PulidEvaClipLoader": "Load Eva Clip (PuLID)", + "ApplyPulid": "Apply PuLID", + "ApplyPulidAdvanced": "Apply PuLID Advanced", +} \ No newline at end of file diff --git a/PuLID_ComfyUI/pyproject.toml b/PuLID_ComfyUI/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..43833a1c751a946157b2a46e536b16ab6481bd00 --- /dev/null +++ b/PuLID_ComfyUI/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "pulid_comfyui" +description = "PuLID ComfyUI native implementation." +version = "1.0.0" +license = "LICENSE" +dependencies = ["facexlib", "insightface", "onnxruntime", "onnxruntime-gpu", "ftfy", "timm"] + +[project.urls] +Repository = "https://github.com/cubiq/PuLID_ComfyUI" +# Used by Comfy Registry https://comfyregistry.org + +[tool.comfy] +PublisherId = "matteo" +DisplayName = "PuLID_ComfyUI" +Icon = "" diff --git a/PuLID_ComfyUI/requirements.txt b/PuLID_ComfyUI/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ed8eed5ca75ad5c0745f1db6297619681032ab07 --- /dev/null +++ b/PuLID_ComfyUI/requirements.txt @@ -0,0 +1,6 @@ +facexlib +insightface +onnxruntime +onnxruntime-gpu +ftfy +timm diff --git a/example_node.py.example b/example_node.py.example new file mode 100644 index 0000000000000000000000000000000000000000..29ab2aa72319354b147b7dd79e1c3179e54d3d06 --- /dev/null +++ b/example_node.py.example @@ -0,0 +1,155 @@ +class Example: + """ + A example node + + Class methods + ------------- + INPUT_TYPES (dict): + Tell the main program input parameters of nodes. + IS_CHANGED: + optional method to control when the node is re executed. + + Attributes + ---------- + RETURN_TYPES (`tuple`): + The type of each element in the output tuple. + RETURN_NAMES (`tuple`): + Optional: The name of each output in the output tuple. + FUNCTION (`str`): + The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute() + OUTPUT_NODE ([`bool`]): + If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example. + The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected. + Assumed to be False if not present. + CATEGORY (`str`): + The category the node should appear in the UI. + DEPRECATED (`bool`): + Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain + functional in existing workflows that use them. 
+ EXPERIMENTAL (`bool`): + Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to + significant changes or removal in future versions. Use with caution in production workflows. + execute(s) -> tuple || None: + The entry point method. The name of this method must be the same as the value of property `FUNCTION`. + For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`. + """ + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + """ + Return a dictionary which contains config for all input fields. + Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT". + Input types "INT", "STRING" or "FLOAT" are special values for fields on the node. + The type can be a list for selection. + + Returns: `dict`: + - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required` + - Value input_fields (`dict`): Contains input fields config: + * Key field_name (`string`): Name of a entry-point method's argument + * Value field_config (`tuple`): + + First value is a string indicate the type of field or a list for selection. + + Second value is a config for type "INT", "STRING" or "FLOAT". + """ + return { + "required": { + "image": ("IMAGE",), + "int_field": ("INT", { + "default": 0, + "min": 0, #Minimum value + "max": 4096, #Maximum value + "step": 64, #Slider's step + "display": "number", # Cosmetic only: display as "number" or "slider" + "lazy": True # Will only be evaluated if check_lazy_status requires it + }), + "float_field": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 10.0, + "step": 0.01, + "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. + "display": "number", + "lazy": True + }), + "print_to_screen": (["enable", "disable"],), + "string_field": ("STRING", { + "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node + "default": "Hello World!", + "lazy": True + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + #RETURN_NAMES = ("image_output_name",) + + FUNCTION = "test" + + #OUTPUT_NODE = False + + CATEGORY = "Example" + + def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen): + """ + Return a list of input names that need to be evaluated. + + This function will be called if there are any lazy inputs which have not yet been + evaluated. As long as you return at least one field which has not yet been evaluated + (and more exist), this function will be called again once the value of the requested + field is available. + + Any evaluated inputs will be passed as arguments to this function. Any unevaluated + inputs will have the value None. + """ + if print_to_screen == "enable": + return ["int_field", "float_field", "string_field"] + else: + return [] + + def test(self, image, string_field, int_field, float_field, print_to_screen): + if print_to_screen == "enable": + print(f"""Your input contains: + string_field aka input text: {string_field} + int_field: {int_field} + float_field: {float_field} + """) + #do some processing on the image, in this example I just invert it + image = 1.0 - image + return (image,) + + """ + The node will always be re executed if any of the inputs change but + this method can be used to force the node to execute again even when the inputs don't change. 
+ You can make this node return a number or a string. This value will be compared to the one returned the last time the node was + executed, if it is different the node will be executed again. + This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash + changes between executions the LoadImage node is executed again. + """ + #@classmethod + #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen): + # return "" + +# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension +# WEB_DIRECTORY = "./somejs" + + +# Add custom API routes, using router +from aiohttp import web +from server import PromptServer + +@PromptServer.instance.routes.get("/hello") +async def get_hello(request): + return web.json_response("hello") + + +# A dictionary that contains all nodes you want to export with their names +# NOTE: names should be globally unique +NODE_CLASS_MAPPINGS = { + "Example": Example +} + +# A dictionary that contains the friendly/humanly readable titles for the nodes +NODE_DISPLAY_NAME_MAPPINGS = { + "Example": "Example Node" +} diff --git a/rgthree-comfy/LICENSE b/rgthree-comfy/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a05fac3464aff37a0d33b4ddcdad0cfd772fb318 --- /dev/null +++ b/rgthree-comfy/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Regis Gaughan, III (rgthree) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/rgthree-comfy/README.md b/rgthree-comfy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..46bfcf89270a950e0a481393e0e2d33c446b31de --- /dev/null +++ b/rgthree-comfy/README.md @@ -0,0 +1,411 @@ +

+ rgthree-comfy +
+ Making ComfyUI more comfortable! +
+

+

+ The Nodes   |   Improvements & Features   |   Link Fixer +

+
+
+A collection of nodes and improvements created while messing around with ComfyUI. I made them for myself to make my workflow cleaner, easier, and faster. You're welcome to try them out. But remember, I made them for my own use cases :)
+
+![Context Node](./docs/rgthree_advanced.png)
+
+# Get Started
+
+## Install
+
+1. Install the great [ComfyUI](https://github.com/comfyanonymous/ComfyUI).
+2. Clone this repo into `custom_nodes`:
+    ```
+    cd ComfyUI/custom_nodes
+    git clone https://github.com/rgthree/rgthree-comfy.git
+    ```
+3. Start up ComfyUI.
+
+## Settings
+
+You can configure certain aspects of rgthree-comfy. For instance, perhaps a future ComfyUI change breaks rgthree-comfy, or you already have another extension that does something similar and you want to turn it off for rgthree-comfy.
+
+You can get to the rgthree-comfy settings by right-clicking on an empty part of the graph and selecting `rgthree-comfy > Settings (rgthree-comfy)`, or by clicking `rgthree-comfy settings` in the ComfyUI settings dialog.
+
+_(Note, settings are stored in an `rgthree_config.json` in the `rgthree-comfy` directory. There are other advanced settings that can only be configured there; you can copy default settings from `rgthree_config.json.default` to `rgthree_config.json` before modifying)_.
+
+ +# ✴️ The Nodes + +Note, you can right-click on a bunch of the rgthree-comfy nodes and select `🛟 Node Help` menu item for in-app help when available. + +## Seed +> An intuitive seed control node for ComfyUI that works very much like Automatic1111's seed control. +>
+> ℹ️ See More Information
+>
+> - Set the seed value to "-1" to use a random seed every time.
+> - Set any other number in there to use as a static/fixed seed.
+> - Quick actions to randomize, or (re-)use the last queued seed.
+> - Image metadata will store the seed value _(so dragging an image in will have the seed field already fixed to its seed)_.
+> - _Secret Features_: You can manually set the seed value to "-2" or "-3" to increment or decrement the last seed value. If there is no last seed value, it will use a random seed the first time.
+>
+> ![Seed Node](./docs/rgthree_seed.png)
+>
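+>
+> For illustration only, here is a rough Python sketch (not the node's actual implementation; `resolve_seed` and `MAX_SEED` are made-up names) of how the special seed values described above could be resolved at queue time:
+>
+> ```python
+> import random
+> from typing import Optional
+>
+> MAX_SEED = 2**32  # assumed upper bound, for illustration only
+>
+> def resolve_seed(requested: int, last_seed: Optional[int]) -> int:
+>     if requested == -1:  # use a random seed every time
+>         return random.randint(0, MAX_SEED - 1)
+>     if requested in (-2, -3):  # increment/decrement the last queued seed
+>         if last_seed is None:  # no last seed yet, so fall back to random
+>             return random.randint(0, MAX_SEED - 1)
+>         return last_seed + 1 if requested == -2 else last_seed - 1
+>     return requested  # any other value is used as a static/fixed seed
+> ```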
+ + +## Reroute +> Keep your workflow neat with this much improved Reroute node with, like, actual rerouting with multiple directions and sizes. +>
+> ℹ️ More Information +> +> - Use the right-click context menu to change the width, height and connection layout +> - Also toggle resizability (min size is 40x43 if resizing though), and title/type display. +> +> ![Router Node](./docs/rgthree_router.png) +>
+
+## Bookmark (🔖)
+> Place the bookmark node anywhere on screen to quickly navigate to that spot with a shortcut key.
+>
+> ℹ️ See More Information
+>
+> - Define the `shortcut_key` to press to go right to that bookmark node, anchored in the top left.
+> - You can also define the zoom level!
+> - Pro tip: `shortcut_key` can be multiple keys. For instance "alt + shift + !" would require
+>   pressing the alt key, the shift key, and the "!" (as in the "1" key, but with shift pressed)
+>   in order to trigger.
+>
+
+
+## Context / Context Big
+> Pass along general flow properties and merge in new data. Similar to some other node suites' "pipes," but with easier merging, and more easily interoperable with standard nodes by both combining and exploding all in a single node.
+>
+> ℹ️ More Information
+>
+> - Context and Context Big are backwards compatible with each other. That is, an input connected to a Context Big will be passed through the CONTEXT outputs of normal Context nodes and be available as an output on either (or only on Context Big if the output exists only on that node, like "steps").
+> - Pro Tip: When dragging a Context output over another node, hold down "ctrl" and release to automatically connect the other Context outputs to the hovered node.
+> - Pro Tip: You can change between Context and Context Big nodes from the menu.
+>
+> ![Context Node](./docs/rgthree_context.png)
+>
+ +## Image Comparer +> The Image Comparer node compares two images on top of each other. +>
+> ℹ️ More Information
+>
+> - **Note:** The right-click menu may show image options (Open Image, Save Image, etc.) which will correspond to the first image (image_a) if clicked on the left half of the node, or the second image if on the right half of the node.
+> - **Inputs:**
+>   - `image_a` _Required._ The first image to compare. If image_b is not supplied and image_a is a batch, the comparer will use the first two images of image_a.
+>   - `image_b` _Optional._ The second image to compare. Optional only if image_a is a batch with two images.
+> - **Properties:** You can change the following properties (by right-clicking on the node and selecting "Properties" or "Properties Panel" from the menu):
+>   - `comparer_mode` - Choose between "Slide" and "Click". Defaults to "Slide".
+
+
+## Image Inset Crop
+> A node that lets you crop an input image by either pixel values or percentage values.
+
+
+## Display Any
+> Displays most any piece of text data from the backend _after execution_.
+
+## Power Lora Loader
+> A super-simple Lora Loader node that can load multiple Loras at once, and quickly toggle each, all in an ultra-condensed node.
+>
+> ℹ️ More Information
+>
+> - Add as many Loras as you would like by clicking the "+ Add Lora" button. There's no real limit!
+> - Right-click on a Lora widget for special options to move the lora up or down
+>   _(no effect on the image, just presentation)_, toggle it on/off, or delete the row altogether.
+> - From the properties, change `Show Strengths` to choose between showing a single, simple
+>   strength value (which will be used for both model and clip), or a more advanced view with
+>   both model and clip strengths being modifiable.
+>
+
+
+## ~~Lora Loader Stack~~
+> _**Deprecated.** Use the `Power Lora Loader` instead._
+>
+> A simplified Lora Loader stack. Much like other suites, but more interoperable with standard inputs/outputs.
+
+
+## Power Prompt
+> Power up your prompt and get dropdowns for adding your embeddings and loras, and even saved prompt snippets.
+>
+> ℹ️ More Information
+>
+> - At the core, you can use Power Prompt almost as a String Primitive node with the additional features of dropdowns for choosing your embeddings, and even loras, with no further processing. This will output just the raw `TEXT` to another node for any lora processing, CLIP encoding, etc.
+> - Connect a `CLIP` to the input to encode the text, with both the `CLIP` and `CONDITIONING` output right from the node.
+> - Connect a `MODEL` to the input to parse and load any `<lora>` tags in the text automatically, without
+>   needing separate Lora Loaders.
+>
+
+## Power Prompt - Simple
+> Same as Power Prompt above, but without LORA support; made for a slightly cleaner negative prompt _(since negative prompts do not support loras)_.
+
+## SDXL Power Prompt - Positive
+> The SDXL sibling to the Power Prompt above. It contains the text_g and text_l as separate text inputs, as well as a couple more input slots necessary to ensure proper clip encoding. Combine with
+
+## SDXL Power Prompt - Simple
+> Like the non-SDXL `Power Prompt - Simple` node, this one is essentially the same as the SDXL Power Prompt but without lora support, for either non-lora positive prompts or SDXL negative prompts _(since negative prompts do not support loras)_.
+
+## SDXL Config
+> Just some configuration fields for SDXL prompting. Honestly, could be used for non-SDXL too.
+
+## Context Switch / Context Switch Big
+> A powerful node to branch your workflow. Works by choosing the first Context input that is not null/empty.
+>
+> ℹ️ More Information
+>
+> - Pass in several context nodes and the Context Switch will automatically choose the first non-null context to continue onward with.
+> - Wondering how to toggle contexts to null? Use in conjunction with the **Fast Muter** or **Fast Groups Muter**.
+>
+>
+
+## Any Switch
+> A powerful node, similar to the Context Switch above, that chooses the first input that is not null/empty.
+>
+> ℹ️ More Information
+>
+> - Pass in several inputs of the same type and the Any Switch will automatically choose the first non-null value to continue onward with.
+> - Wondering how to toggle inputs to null? Use in conjunction with the **Fast Muter** or **Fast Groups Muter**.
+>
+>
+ + +## Fast Groups Muter +> The Fast Groups Muter is an input-less node that automatically collects all groups in your current workflow and allows you to quickly mute and unmute all nodes within the group. +>
+> ℹ️ More Information
+>
+> - Groups will automatically be shown, though you can filter, sort and more from the **node Properties** _(by right-clicking on the node and selecting "Properties" or "Properties Panel" from the menu)_. Properties include:
+>   - `matchColors` - Only add groups that match the provided colors. Can be ComfyUI colors (red, pale_blue) or hex codes (#a4d399). Multiple can be added, comma delimited.
+>   - `matchTitle` - Filter the list of toggles by title match (string match, or regular expression).
+>   - `showNav` - Add / remove a quick navigation arrow to take you to the group. (default: true)
+>   - `sort` - Sort the toggles' order by "alphanumeric", graph "position", or "custom alphabet". (default: "position")
+>   - `customSortAlphabet` - When the sort property is "custom alphabet" you can define the alphabet to use here, which will match the beginning of each group name and sort against it. If group titles do not match any custom alphabet entry, they will be put after groups that do, ordered alphanumerically.
+>
+>     This can be a list of single characters, like "zyxw..." or comma delimited strings for more control, like "sdxl,pro,sd,n,p".
+>
+>     Note, when two group titles match the same custom alphabet entry, the normal alphanumeric alphabet breaks the tie. For instance, a custom alphabet of "e,s,d" will order group names like "SDXL, SEGS, Detailer" even though the custom alphabet has an "e" before "d" (where one may expect "SE" to be before "SD"). _(See the sketch below for an illustration.)_
+>
+>     To have "SEGS" appear before "SDXL" you can use longer strings. For instance, the custom alphabet value of "se,s,f" would work here.
+>   - `toggleRestriction` - Optionally, attempt to restrict the number of widgets that can be enabled to a maximum of one, or always one.
+>
+>     _Note: If using "max one" or "always one" then this is only enforced when clicking a toggle on this node; if nodes within groups are changed outside of the initial toggle click, then these restrictions will not be enforced, and could result in a state where more than one toggle is enabled. This could also happen if nodes are overlapped with multiple groups._
+>
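+>
+> For illustration only, here is a rough Python sketch (not the extension's actual code; `custom_alphabet_key` is a made-up name) of how the "custom alphabet" ordering described above could work:
+>
+> ```python
+> def custom_alphabet_key(title, alphabet):
+>     """Ranks a group title by the first custom-alphabet entry it starts with."""
+>     lowered = title.lower()
+>     for rank, entry in enumerate(alphabet):
+>         if lowered.startswith(entry.lower()):
+>             # Matching titles sort by the entry's position in the custom alphabet;
+>             # ties fall back to normal alphanumeric order of the full title.
+>             return (0, rank, lowered)
+>     # Titles that match no entry go after all matching titles, alphanumerically.
+>     return (1, 0, lowered)
+>
+> groups = ["SEGS", "Detailer", "SDXL"]
+> print(sorted(groups, key=lambda t: custom_alphabet_key(t, ["e", "s", "d"])))
+> # -> ['SDXL', 'SEGS', 'Detailer']
+> print(sorted(groups, key=lambda t: custom_alphabet_key(t, ["se", "s", "f"])))
+> # -> ['SEGS', 'SDXL', 'Detailer']
+> ```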
+
+## Fast Groups Bypasser
+> _Same as **Fast Groups Muter** above, but sets the connected nodes to "Bypass" instead of "Mute"._
+
+
+## Fast Muter
+> A powerful 'control panel' node to quickly toggle connected nodes, allowing them to be muted or enabled.
+>
+> ℹ️ More Information
+>
+> - Collects all connected nodes, providing a single spot, like a "dashboard", to quickly enable and disable nodes. There are two distinct nodes: one for "Muting" connected nodes, and one for "Bypassing" connected nodes.
+>
+
+
+## Fast Bypasser
+> Same as Fast Muter but sets the connected nodes to "Bypass".
+
+## Fast Actions Button
+> Oh boy, this node allows you to semi-automate connected nodes and/or ComfyUI.
+>
+> ℹ️ More Information
+>
+> - Connect nodes and, at the least, mute, bypass or enable them when the button is pressed.
+> - Certain nodes expose additional actions. For instance, for the `Seed` node you can set `Randomize Each Time` or `Use Last Queued Seed` when the button is pressed.
+> - Also, from the node properties, set a shortcut key to trigger the button actions without needing a click!
+>
+
+
+## Node Collector
+> Used to clean up noodles, this will accept any number of input nodes and pass them along to another node.
+>
+> ⚠️ *Currently, this should really only be connected to **Fast Muter**, **Fast Bypasser**, or **Mute / Bypass Relay**.*
+
+
+## Mute / Bypass Repeater
+> A powerful node that will dispatch its Mute/Bypass/Active mode to all connected input nodes or, if in a group w/o any connected inputs, will dispatch its Mute/Bypass/Active mode to all nodes in that group.
+>
+> ℹ️ More Information +> +> - 💡 Pro Tip #1: Connect this node's output to a **Fast Muter** or **Fast Bypasser** to have a single toggle there that can mute/bypass/enable many nodes with one click. +> +> - 💡 Pro Tip #2: Connect a **Mute / Bypass Relay** node to this node's inputs to have the relay automatically dispatch a mute/bypass/enable change to the repeater. +>
+ + +## Mute / Bypass Relay +> An advanced node that, when working with a **Mute / Bypass Repeater**, will relay its input nodes' +> modes (Mute, Bypass, or Active) to a connected repeater (which would then repeat that mode change +> to all of its inputs). +>
+> ℹ️ More Information +> +> - When all connected input nodes are muted, the relay will set a connected repeater to mute (by +> default). +> - When all connected input nodes are bypassed, the relay will set a connected repeater to +> bypass (by default). +> - When _any_ connected input nodes are active, the relay will set a connected repeater to +> active (by default). +> - **Note:** If no inputs are connected, the relay will set a connected repeater to its mode +> _when its own mode is changed_. **Note**, if any inputs are connected, then the above bullets +> will occur and the Relay's mode does not matter. +> - **Pro Tip:** You can change which signals get sent on the above in the `Properties`. +> For instance, you could configure an inverse relay which will send a MUTE when any of its +> inputs are active (instead of sending an ACTIVE signal), and send an ACTIVE signal when all +> of its inputs are muted (instead of sending a MUTE signal), etc. +>
+
+
+## Random Unmuter
+> An advanced node used to unmute one of its inputs randomly when the graph is queued (and immediately mute it back).
+>
+> ℹ️ More Information
+>
+> - **Note:** All input nodes MUST be muted to start; if not, this node will not randomly unmute one. (This is powerful, as the generated image can be dragged in and the chosen input will already be unmuted and work w/o any further action.)
+> - **Tip:** Connect a Repeater's output to this node's input and place that Repeater on a group without any other inputs, and it will mute/unmute the entire group.
+>
+ + +## Label +> A purely visual node, this allows you to add a floating label to your workflow. +>
+> ℹ️ More Information
+>
+> - The text shown is the "Title" of the node and you can adjust the font size, font family,
+>   font color, text alignment as well as a background color, padding, and background border
+>   radius from the node's properties. You can double-click the node to open the properties
+>   panel.
+> - **Pro Tip #1:** You can add multiline text from the properties panel _(because ComfyUI lets
+>   you shift + enter there, only)._
+> - **Pro Tip #2:** You can use ComfyUI's native "pin" option in the right-click menu to make the
+>   label stick to the workflow and let clicks "go through". You can right-click at any time to
+>   unpin.
+> - **Pro Tip #3:** Color values are hexadecimal strings, like "#FFFFFF" for white, or "#660000"
+>   for dark red. You can supply a 7th & 8th value (or 5th if using shorthand) to create a
+>   translucent color. For instance, "#FFFFFF88" is semi-transparent white.
+>
+
+
+# Advanced Techniques
+
+## First, a word on muting
+
+A lot of the power of these nodes comes from *Muting*. Muting is the basis of correctly implementing multiple paths for a workflow utilizing the Context Switch node.
+
+While other extensions may provide switches, they often get it wrong, causing your workflow to do more work than is needed. While other switches may have a selector to choose which input to pass along, they don't stop the execution of the other inputs, which will result in wasted work. Instead, Context Switch works by choosing the first non-empty context to pass along; correctly Muting is one way to make a previous node's output empty, and causes no extra work to be done when set up correctly.
+
+### To understand muting is to understand the graph flow
+
+Muting, and therefore using Switches, can often confuse people at first because it _feels_ like muting a node, or using a switch, should be able to stop or direct the _forward_ flow of the graph. However, this is not the case and, in fact, the graph actually starts working backwards.
+
+If you have a workflow that has a path like `... > Context > KSampler > VAE Decode > Save Image` it may initially _feel_ like you should be able to mute that first Context node and the graph would stop there when moving forward and skip the rest of that workflow.
+
+But you'll quickly find that will cause an error, because the graph doesn't actually move forward. When a workflow is processed, it _first moves backwards_, starting at each "Output Node" (Preview Image, Save Image, even "Display String", etc.) and then walking backwards along all possible paths to get there.
+
+So, with that `... > Context > KSampler > VAE Decode > Save Image` example from above, we actually want to mute the `Save Image` node to stop this path. Once we do, since the output node is gone, none of these nodes will be run. _(A rough sketch of this backwards walk is shown at the end of this section.)_
+
+Let's take a look at an example.
+
+### A powerful combination: Using Context, Context Switch, & Fast Muter
+
+![Context Node](./docs/rgthree_advanced.png)
+
+1. Using the **Context Switch** (aqua colored in the screenshot), feed context inputs in order of preference. In the workflow above, the `Upscale Out` context is first so, if that one is enabled, it will be chosen for the output. If not, the second input slot, which comes from the context rerouted from above (before the Upscaler booth), will be chosen.
+
+   - Notice the `Upscale Preview` is _after_ the `Upscale Out` context node, using the image from it instead of the image from the upscale `VAE Decoder`. This is on purpose so, when we disable the `Upscale Out` context, none of the Upscaler nodes will run, saving precious GPU cycles. If we had the preview hooked up directly to the `VAE Decoder`, the upscaler would always run to generate the preview, even if we had the `Upscale Out` context node disabled.
+
+2. We can now disable the `Upscale Out` context node by _muting_ it. Highlighting it and pressing `ctrl + m` will work. By doing so, its output will be None, and it will not pass anything on to the further nodes. In the diagram you can see the `Upscale Preview` is red, but that's OK; there are no actual errors to stop execution.
+
+3. Now, let's hook it up to the `Fast Muter` node. The `Fast Muter` node works as a dashboard by adding quick toggles for any connected node (ignoring reroutes). In the diagram, we have both the `Upscale Out` context node and the `Save File` context node hooked up. So, we can quickly enable and disable those.
+ + - The workflow seen here would be a common one where we can generate a handful of base previews cheaply with a random seed, and then choose one to upscale and save to disk. + +4. Lastly, and optionally, you can see the `Node Collector`. Use it to clean up noodles if you want and connect it to the muter. You can connect anything to it, but doing so may break your workflow's execution. + +
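+
+For a rough illustration of the backwards walk described above, here is a minimal Python sketch (not ComfyUI's actual code; `collect_nodes_to_run` is a made-up name). It assumes the prompt format used elsewhere in this repo, where a linked input is a `[source_node_id, output_index]` pair:
+
+```python
+def collect_nodes_to_run(prompt, output_node_ids):
+    """Walks backwards from each output node, gathering every ancestor node."""
+    to_run = set()
+    stack = list(output_node_ids)
+    while stack:
+        node_id = stack.pop()
+        if node_id in to_run:
+            continue
+        to_run.add(node_id)
+        for value in prompt[node_id]["inputs"].values():
+            # A linked input is a [source_node_id, output_index] pair.
+            if isinstance(value, list):
+                stack.append(value[0])
+    return to_run
+```
+
+Muting an output node removes it from the set of starting points, so nothing upstream of it is ever collected; muting a node in the middle of a path does not stop the walk on its own, which is why it produces an error instead.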
+
+# ⚡ Improvements & Features
+
+rgthree-comfy adds several improvements, features, and optimizations to ComfyUI that are not directly tied to nodes.
+
+## Progress Bar
+> A minimal progress bar that runs along the top of the app window and shows the queue size, the current progress of a prompt execution (within the same window), and the progress of multi-step nodes as well.
+>
+> You can enable/remove it from the rgthree-comfy settings, as well as configure the height/size.
+
+
+## ComfyUI Recursive Optimization
+> An optimization to ComfyUI's recursive execution. Because rgthree-comfy nodes make it easy to build larger, more complex workflows, I (and others) started to hit a wall of poor execution times.
+>
+> ℹ️ More Information
+>
+> - Until [ComfyUI/issues/1502](https://github.com/comfyanonymous/ComfyUI/issues/1502) is resolved and/or [ComfyUI/pull/1503](https://github.com/comfyanonymous/ComfyUI/pull/1503) is pulled in, know that you're benefiting from hundreds of millions of saved cycles each run.
+>
+> - Specifically, for a rather complex test workflow, the patch reduces iterations of `recursive_will_execute` from 113,292,566 to just 135 (and 116.32 seconds to 69.84 seconds on my machine) on a fresh queue, and reduces recursive calls of `recursive_output_delete_if_changed` from 250,496,808 to 142 (and 158.13 seconds to 0.0 seconds on my machine).
+>
+> - ⚠️ *However,* there is a chance ComfyUI changes something in/around the code I patched, which could break it. If that's the case, you should disable the optimization from rgthree-comfy settings.
+>
+>
+ + +## "Queue Selected Output Nodes" in right-click menu +> Sometimes you want to just queue one or two paths to specific output node(s) without executing the entire workflow. Well, now you can do just that by right-clicking on an output node and selecting `Queue Selected Output Nodes (rgthree)`. +> +>
+> ℹ️ More Information +> +> - Select the _output_ nodes you want to execute. +> +> - Note: Only output nodes are captured and traversed, not all selected nodes. So if you select an output AND a node from a different path, only the path connected to the output will be executed and not non-output nodes, even if they were selected. +> +> - Note: The whole workflow is serialized, and then we trim what we don't want for the backend. So things like all seed random/increment/decrement will run even if that node isn't being sent in the end, etc. +> +>
+
+
+## Auto-Nest Subdirectories in long Combos
+> _(Off by default while experimenting, turn on in rgthree-comfy settings)_.
+>
+> Automatically detects top-level subdirectories in long combo lists (like Load Checkpoint) and breaks them out into subdirectories.
+
+
+## Quick Mute/Bypass Toggles in Group Headers
+> _(Off by default while experimenting, turn on in rgthree-comfy settings)_.
+>
+> Adds mute and/or bypass toggle icons in the top-right of Group Headers for one-click toggling of groups you may be currently looking at.
+
+
+## Import Individual Node Widgets (Drag & Drop)
+> _(Off by default while experimenting, turn on in rgthree-comfy settings)_.
+>
+> Allows dragging and dropping an image/JSON workflow from a previous generation and overriding the same node's widgets
+> (that match with the same id & type). This is useful if you have several generations using the same general workflow
+> and would like to import just some data, like a previous generation's seed, or prompt, etc.
+
+
+
+## "Copy Image" in right-click menu
+> Right-clicking on a node that has an image shows a "Copy Image" context-menu item, which allows you to copy the image right to your clipboard.
+>
+> 🎓 I believe this has graduated, with ComfyUI recently adding this setting too. You won't get two menu items; my code checks that there isn't already a "Copy Image" item there before adding it.
+
+
+## Other/Smaller Fixes
+- Fixed the width of the ultra-wide node chooser on double-click.
+- Fixed z-indexes for textareas that would overlap above other elements, like the Properties Panel, or @pythongosssss's image viewer.
+- Check for bad links when loading a workflow and log to console, by default. _(See Link Fixer below)._
+
+ +# 📄 Link Fixer + +If your workflows sometimes have missing connections, or even errors on load, start up ComfyUI and go to http://127.0.0.1:8188/rgthree/link_fixer which will allow you to drop in an image or workflow json file and check for and fix any bad links. + +You can also enable a link fixer check in the rgthree-comfy settings to give you an alert if you load a workflow with bad linking data to start. diff --git a/rgthree-comfy/__build__.py b/rgthree-comfy/__build__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7832e6943450a73fef5529cce79e1402c094bcd --- /dev/null +++ b/rgthree-comfy/__build__.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 + +import subprocess +import os +from shutil import rmtree, copytree, ignore_patterns +from glob import glob +import time +import re +import argparse + +from py.log import COLORS +from py.config import RGTHREE_CONFIG + +start = time.time() + +parser = argparse.ArgumentParser() +parser.add_argument("-t", "--with-tests", default=False, action="store_true") +parser.add_argument("-f", "--fix", default=False, action="store_true") +args = parser.parse_args() + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +DIR_SRC_WEB = os.path.abspath(f'{THIS_DIR}/src_web/') +DIR_WEB = os.path.abspath(f'{THIS_DIR}/web/') +DIR_WEB_COMFYUI = os.path.abspath(f'{DIR_WEB}/comfyui/') + + +def log_step(msg=None, status=None): + """ Logs a step keeping track of timing and initial msg. """ + global step_msg # pylint: disable=W0601 + global step_start # pylint: disable=W0601 + global step_warns # pylint: disable=W0601 + if msg: + tag = f'{COLORS["YELLOW"]}[ Notice ]' if status == 'Notice' else f'{COLORS["RESET"]}[Starting]' + step_msg = f'▻ {tag}{COLORS["RESET"]} {msg}...' + step_start = time.time() + step_warns = [] + print(step_msg, end="\r") + elif status: + if status != 'Error': + status = "Warn" if len(step_warns) > 0 else status + step_time = round(time.time() - step_start, 3) + if status == 'Error': + status_msg = f'{COLORS["RED"]}⤫ {status}{COLORS["RESET"]}' + elif status == 'Warn': + status_msg = f'{COLORS["YELLOW"]}! {status}{COLORS["RESET"]}' + else: + status_msg = f'{COLORS["BRIGHT_GREEN"]}🗸 {status}{COLORS["RESET"]}' + print(f'{step_msg.ljust(64, ".")} {status_msg} ({step_time}s)') + for warning in step_warns: + print(warning) + + +if args.fix: + tss = glob(os.path.join(DIR_SRC_WEB, "**", "*.ts"), recursive=True) + log_step(msg=f'Fixing {len(tss)} ts files') + for ts in tss: + with open(ts, 'r', encoding="utf-8") as f: + content = f.read() + # (\s*from\s*['"](?!.*[.]js['"]).*?)(['"];) in vscode. 
+ content, n = re.subn(r'(\s*from [\'"](?!.*[.]js[\'"]).*?)([\'"];)', '\\1.js\\2', content) + if n > 0: + filename = os.path.basename(ts) + step_warns.append( + f' - {filename} has {n} import{"s" if n > 1 else ""} that do not end in ".js"') + with open(ts, 'w', encoding="utf-8") as f: + f.write(content) + log_step(status="Done") + +log_step(msg='Copying web directory') +rmtree(DIR_WEB) +copytree(DIR_SRC_WEB, DIR_WEB, ignore=ignore_patterns("typings*", "*.ts", "*.scss")) +log_step(status="Done") + +ts_version_result = subprocess.run(["node", "./node_modules/typescript/bin/tsc", "-v"], + capture_output=True, + text=True, + check=True) +ts_version = re.sub(r'^.*Version\s*([\d\.]+).*', 'v\\1', ts_version_result.stdout, flags=re.DOTALL) + +log_step(msg=f'TypeScript ({ts_version})') +checked = subprocess.run(["node", "./node_modules/typescript/bin/tsc"], check=True) +log_step(status="Done") + +if args.with_tests: + log_step(msg='Removing directories (KEEPING TESTING)', status="Notice") +else: + log_step(msg='Removing uneeded directories') + test_path = os.path.join(DIR_WEB, 'comfyui', 'tests') + if os.path.exists(test_path): + rmtree(test_path) + rmtree(os.path.join(DIR_WEB, 'comfyui', 'testing')) +# Always remove the dummy scripts_comfy directory +rmtree(os.path.join(DIR_WEB, 'scripts_comfy')) +log_step(status="Done") + +scsss = glob(os.path.join(DIR_SRC_WEB, "**", "*.scss"), recursive=True) +log_step(msg=f'SASS for {len(scsss)} files') +scsss = [i.replace(THIS_DIR, '.') for i in scsss] +cmds = ["node", "./node_modules/sass/sass"] +for scss in scsss: + out = scss.replace('src_web', 'web').replace('.scss', '.css') + cmds.append(f'{scss}:{out}') +cmds.append('--no-source-map') +checked = subprocess.run(cmds, check=True) +log_step(status="Done") + +# Handle the common directories. Because ComfyUI loads under /extensions/rgthree-comfy we can't +# easily share sources outside of the `DIR_WEB_COMFYUI` _and_ allow typescript to resolve them in +# src view, so we set the path in the tsconfig to map an import of "rgthree/common" to the +# "src_web/common" directory, but then need to rewrite the comfyui JS files to load from +# "../../rgthree/common" (which we map correctly in rgthree_server.py). 
+log_step(msg='Cleaning Imports') +js_files = glob(os.path.join(DIR_WEB, '**', '*.js'), recursive=True) +for file in js_files: + rel_path = file.replace(f'{DIR_WEB}/', "") + with open(file, 'r', encoding="utf-8") as f: + filedata = f.read() + num = rel_path.count(os.sep) + if rel_path.startswith('comfyui'): + filedata = re.sub(r'(from\s+["\'])rgthree/', f'\\1{"../" * (num + 1)}rgthree/', filedata) + filedata = re.sub(r'(from\s+["\'])scripts/', f'\\1{"../" * (num + 1)}scripts/', filedata) + else: + filedata = re.sub(r'(from\s+["\'])rgthree/', f'\\1{"../" * num}', filedata) + filedata = re.sub(r'(from\s+["\'])scripts/', f'\\1{"../" * (num + 1)}scripts/', filedata) + filedata, n = re.subn(r'(\s*from [\'"](?!.*[.]js[\'"]).*?)([\'"];)', '\\1.js\\2', filedata) + if n > 0: + filename = os.path.basename(file) + step_warns.append( + f' - {filename} has {n} import{"s" if n > 1 else ""} that do not end in ".js"') + with open(file, 'w', encoding="utf-8") as f: + f.write(filedata) +log_step(status="Done") + +print(f'Finished all in {round(time.time() - start, 3)}s') diff --git a/rgthree-comfy/__init__.py b/rgthree-comfy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aebf98da57e0f20c52eb4f98e279fbab8c4e3e3a --- /dev/null +++ b/rgthree-comfy/__init__.py @@ -0,0 +1,321 @@ +""" +@author: rgthree +@title: Comfy Nodes +@nickname: rgthree +@description: A bunch of nodes I created that I also find useful. +""" + +from glob import glob +import json +import os +import shutil +import re +import random + +import execution + +from .py.log import log +from .py.config import get_config_value +from .py.rgthree_server import * + +from .py.context import RgthreeContext +from .py.context_switch import RgthreeContextSwitch +from .py.context_switch_big import RgthreeContextSwitchBig +from .py.display_any import RgthreeDisplayAny, RgthreeDisplayInt +from .py.lora_stack import RgthreeLoraLoaderStack +from .py.seed import RgthreeSeed +from .py.sdxl_empty_latent_image import RgthreeSDXLEmptyLatentImage +from .py.power_prompt import RgthreePowerPrompt +from .py.power_prompt_simple import RgthreePowerPromptSimple +from .py.image_inset_crop import RgthreeImageInsetCrop +from .py.context_big import RgthreeBigContext +from .py.dynamic_context import RgthreeDynamicContext +from .py.dynamic_context_switch import RgthreeDynamicContextSwitch +from .py.ksampler_config import RgthreeKSamplerConfig +from .py.sdxl_power_prompt_postive import RgthreeSDXLPowerPromptPositive +from .py.sdxl_power_prompt_simple import RgthreeSDXLPowerPromptSimple +from .py.any_switch import RgthreeAnySwitch +from .py.context_merge import RgthreeContextMerge +from .py.context_merge_big import RgthreeContextMergeBig +from .py.image_comparer import RgthreeImageComparer +from .py.power_lora_loader import RgthreePowerLoraLoader + +NODE_CLASS_MAPPINGS = { + RgthreeBigContext.NAME: RgthreeBigContext, + RgthreeContext.NAME: RgthreeContext, + RgthreeContextSwitch.NAME: RgthreeContextSwitch, + RgthreeContextSwitchBig.NAME: RgthreeContextSwitchBig, + RgthreeContextMerge.NAME: RgthreeContextMerge, + RgthreeContextMergeBig.NAME: RgthreeContextMergeBig, + RgthreeDisplayInt.NAME: RgthreeDisplayInt, + RgthreeDisplayAny.NAME: RgthreeDisplayAny, + RgthreeLoraLoaderStack.NAME: RgthreeLoraLoaderStack, + RgthreeSeed.NAME: RgthreeSeed, + RgthreeImageInsetCrop.NAME: RgthreeImageInsetCrop, + RgthreePowerPrompt.NAME: RgthreePowerPrompt, + RgthreePowerPromptSimple.NAME: RgthreePowerPromptSimple, + RgthreeKSamplerConfig.NAME: RgthreeKSamplerConfig, + 
RgthreeSDXLEmptyLatentImage.NAME: RgthreeSDXLEmptyLatentImage, + RgthreeSDXLPowerPromptPositive.NAME: RgthreeSDXLPowerPromptPositive, + RgthreeSDXLPowerPromptSimple.NAME: RgthreeSDXLPowerPromptSimple, + RgthreeAnySwitch.NAME: RgthreeAnySwitch, + RgthreeImageComparer.NAME: RgthreeImageComparer, + RgthreePowerLoraLoader.NAME: RgthreePowerLoraLoader, +} + +if get_config_value('unreleased.dynamic_context.enabled') is True: + NODE_CLASS_MAPPINGS[RgthreeDynamicContext.NAME] = RgthreeDynamicContext + NODE_CLASS_MAPPINGS[RgthreeDynamicContextSwitch.NAME] = RgthreeDynamicContextSwitch + +# WEB_DIRECTORY is the comfyui nodes directory that ComfyUI will link and auto-load. +WEB_DIRECTORY = "./web/comfyui" + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +DIR_WEB = os.path.abspath(f'{THIS_DIR}/{WEB_DIRECTORY}') +DIR_PY = os.path.abspath(f'{THIS_DIR}/py') + +# remove old directories +OLD_DIRS = [ + os.path.abspath(f'{THIS_DIR}/../../web/extensions/rgthree'), + os.path.abspath(f'{THIS_DIR}/../../web/extensions/rgthree-comfy'), +] +for old_dir in OLD_DIRS: + if os.path.exists(old_dir): + shutil.rmtree(old_dir) + +__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY'] + +NOT_NODES = ['constants', 'log', 'utils', 'rgthree', 'rgthree_server', 'image_clipbaord', 'config'] + +nodes = [] +for file in glob(os.path.join(DIR_PY, '*.py')) + glob(os.path.join(DIR_WEB, '*.js')): + name = os.path.splitext(os.path.basename(file))[0] + if name in NOT_NODES or name in nodes: + continue + if name.startswith('_') or name.startswith('base') or 'utils' in name: + continue + nodes.append(name) + if name == 'display_any': + nodes.append('display_int') + +print() +adjs = ['exciting', 'extraordinary', 'epic', 'fantastic', 'magnificent'] +log(f'Loaded {len(nodes)} {random.choice(adjs)} nodes.', color='BRIGHT_GREEN') + +# Alright, I don't like doing this, but until https://github.com/comfyanonymous/ComfyUI/issues/1502 +# and/or https://github.com/comfyanonymous/ComfyUI/pull/1503 is pulled into ComfyUI, we need a way +# to optimize the recursion that happens on prompt eval. This is particularly important for +# rgthree nodes because workflows can contain many context nodes, but the problem would exist for +# other nodes' (like "pipe" nodes, efficieny nodes). With `Context Big` nodes being +# introduced, the number of input recursion that happens in these methods is exponential with a +# saving of 1000's of percentage points over. + +# We'll use this to check if we _can_ patch execution. Other work to change the execution may +# remove these methods, and we want to ensure people's apps do not break. +could_patch_execution = (hasattr(execution, 'recursive_output_delete_if_changed') and + hasattr(execution, 'recursive_will_execute') and + hasattr(execution.PromptExecutor, 'execute')) + +if get_config_value('features.patch_recursive_execution') is True: + if not could_patch_execution: + log("NOTE: Will NOT use rgthree's optimized recursive execution as ComfyUI has changed.", + color='YELLOW') + else: + log("Will use rgthree's optimized recursive execution.", color='BRIGHT_GREEN') + + +class RgthreePatchRecursiveExecute_Set_patch_recursive_execution_to_false_if_not_working: + """A fake 'list' that the caller for recursive_will_execute expects but we override such that + `len(inst)` will return the count number, and `inst[-1]` will return the unique_id. Since that + all the caller cares about, we can save several minutes and many MB of ram by simply counting + numbers instead of concatenating a list of millions (only to count it). 
However the caller + expects such a list, so we fake it with this. + + This mimics the enhancement from https://github.com/rgthree/ComfyUI/commit/50b3fb1 but without + modifying the execution.py + """ + + def __init__(self, unique_id): + self.unique_id = unique_id + self.count = 0 + + def add(self, value): + self.count += value + + def __getitem__(self, key): + """Returns the `unique_id` with '-1' since that's what the caller expects.""" + if key == -1: + return self.unique_id + # This one would future proof the proposed changes, in that case "0" is the count + if key == 0: + return self.count + else: + return -1 + + def __len__(self): + """Returns the "count" of the "list" as if we were building up a list instea of just + incrementing `count`. + """ + return self.count + + # The following (hopefully) future proofs if https://github.com/rgthree/ComfyUI/commit/50b3fb1 + # goes in, which changes from using `len` on a list, to sort directly (and, thus "<" and ">"). + def __gt__(self, other): + return self.count > other + + def __lt__(self, other): + return self.count < other + + def __str__(self): + return str(( + self.count, + self.unique_id, + )) + + +# Caches which will be cleared on each run +execution.rgthree_cache_recursive_output_delete_if_changed_output = {} +execution.rgthree_cache_recursive_will_execute = {} +execution.rgthree_is_currently_optimized = False + + +def rgthree_execute(self, *args, **kwargs): + """ A patch of ComfyUI's default execution for optimization (or un-optimization) via config.""" + if get_config_value('features.patch_recursive_execution') is True: + + if could_patch_execution: + log("Using rgthree's optimized recursive execution.", color='GREEN') + # When we execute, we'll reset our global cache here. + execution.rgthree_cache_recursive_output_delete_if_changed_output = {} + execution.rgthree_cache_recursive_will_execute = {} + + if not execution.rgthree_is_currently_optimized: + log("First run patching recursive_output_delete_if_changed and recursive_will_execute.", + color='GREEN', + msg_color='RESET') + log( + "Note: \33[0mIf execution seems broken due to forward ComfyUI changes, you can disable " + + "the optimization from rgthree settings in ComfyUI.", + color='YELLOW') + execution.rgthree_old_recursive_output_delete_if_changed = execution.recursive_output_delete_if_changed + execution.recursive_output_delete_if_changed = rgthree_recursive_output_delete_if_changed + + execution.rgthree_old_recursive_will_execute = execution.recursive_will_execute + execution.recursive_will_execute = rgthree_recursive_will_execute + execution.rgthree_is_currently_optimized = True + + elif execution.rgthree_is_currently_optimized: + log("Removing optimizations to recursive_output_delete_if_changed and recursive_will_execute.", + color='YELLOW', + msg_color='RESET') + log("You can enable optimization in the rgthree settings in ComfyUI.", color='CYAN') + execution.recursive_output_delete_if_changed = execution.rgthree_old_recursive_output_delete_if_changed + execution.recursive_will_execute = execution.rgthree_old_recursive_will_execute + execution.rgthree_is_currently_optimized = False + + # We always call the original execute, it's just whether we patch or unpacth first. + return self.rgthree_old_execute(*args, **kwargs) + + +# We always patch execute, so we can check if we want to do work. Up in rgthree_execute we will +# either patch or unpatch recursive_will_execute recursive_output_delete_if_changed at runtime when +# config changes. 
+execution.PromptExecutor.rgthree_old_execute = execution.PromptExecutor.execute +execution.PromptExecutor.execute = rgthree_execute + + +def rgthree_recursive_will_execute(prompt, outputs, current_item, *args, **kwargs): + """Patches recursive_will_execute function to cache the result of each output.""" + unique_id = current_item + inputs = prompt[unique_id]['inputs'] + will_execute = RgthreePatchRecursiveExecute_Set_patch_recursive_execution_to_false_if_not_working( + unique_id) + if unique_id in outputs: + return will_execute + + will_execute.add(1) + for x in inputs: + input_data = inputs[x] + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + node_output_cache_key = f'{input_unique_id}.{output_index}' + will_execute_value = None + # If this node's output has already been recursively evaluated, then we can reuse. + if node_output_cache_key in execution.rgthree_cache_recursive_will_execute: + will_execute_value = execution.rgthree_cache_recursive_will_execute[node_output_cache_key] + elif input_unique_id not in outputs: + will_execute_value = execution.recursive_will_execute(prompt, outputs, input_unique_id, + *args, **kwargs) + execution.rgthree_cache_recursive_will_execute[node_output_cache_key] = will_execute_value + if will_execute_value is not None: + will_execute.add(len(will_execute_value)) + return will_execute + + +def rgthree_recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item, *args, + **kwargs): + """Patches recursive_output_delete_if_changed function to cache the result of each output.""" + unique_id = current_item + inputs = prompt[unique_id]['inputs'] + class_type = prompt[unique_id]['class_type'] + class_def = execution.nodes.NODE_CLASS_MAPPINGS[class_type] + + is_changed_old = '' + is_changed = '' + to_delete = False + if hasattr(class_def, 'IS_CHANGED'): + if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: + is_changed_old = old_prompt[unique_id]['is_changed'] + if 'is_changed' not in prompt[unique_id]: + input_data_all = execution.get_input_data(inputs, class_def, unique_id, outputs) + if input_data_all is not None: + try: + #is_changed = class_def.IS_CHANGED(**input_data_all) + is_changed = execution.map_node_over_list(class_def, input_data_all, "IS_CHANGED") + prompt[unique_id]['is_changed'] = is_changed + except: + to_delete = True + else: + is_changed = prompt[unique_id]['is_changed'] + + if unique_id not in outputs: + return True + + if not to_delete: + if is_changed != is_changed_old: + to_delete = True + elif unique_id not in old_prompt: + to_delete = True + elif inputs == old_prompt[unique_id]['inputs']: + for x in inputs: + input_data = inputs[x] + + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + node_output_cache_key = f'{input_unique_id}.{output_index}' + # If this node's output has already been recursively evaluated, then we can stop. 
+ if node_output_cache_key in execution.rgthree_cache_recursive_output_delete_if_changed_output: + to_delete = execution.rgthree_cache_recursive_output_delete_if_changed_output[ + node_output_cache_key] + elif input_unique_id in outputs: + to_delete = execution.recursive_output_delete_if_changed(prompt, old_prompt, outputs, + input_unique_id, *args, + **kwargs) + execution.rgthree_cache_recursive_output_delete_if_changed_output[ + node_output_cache_key] = to_delete + else: + to_delete = True + if to_delete: + break + else: + to_delete = True + + if to_delete: + d = outputs.pop(unique_id) + del d + return to_delete + + +print() diff --git a/rgthree-comfy/__update_comfy__.py b/rgthree-comfy/__update_comfy__.py new file mode 100644 index 0000000000000000000000000000000000000000..262f1129b4b861fff40aa2566fc38715fcfefdfc --- /dev/null +++ b/rgthree-comfy/__update_comfy__.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +# A nicer output for git pulling custom nodes (and ComfyUI). +# Quick shell version: ls | xargs -I % sh -c 'echo; echo %; git -C % pull' + +import os +from subprocess import Popen, PIPE, STDOUT + + +def pull_path(path): + p = Popen(["git", "-C", path, "pull"], stdout=PIPE, stderr=STDOUT) + output, error = p.communicate() + return output.decode() + +THIS_DIR=os.path.dirname(os.path.abspath(__file__)) + +def show_output(output): + if output.startswith('Already up to date'): + print(f' \33[32m🗸 {output}\33[0m', end ='') + elif output.startswith('error:'): + print(f' \33[31m🞫 Error.\33[0m \n {output}') + else: + print(f' \33[33m🡅 Needs update.\33[0m \n {output}', end='') + + +os.chdir(THIS_DIR) +os.chdir("../") + +# Get the list or custom nodes, so we can format the output a little more nicely. +custom_extensions = [] +custom_extensions_name_max = 0 +for directory in os.listdir(os.getcwd()): + if os.path.isdir(directory) and directory != "__pycache__": #and directory != "rgthree-comfy" : + custom_extensions.append({ + 'directory': directory + }) + if len(directory) > custom_extensions_name_max: + custom_extensions_name_max = len(directory) + +if len(custom_extensions) == 0: + custom_extensions_name_max = 15 +else: + custom_extensions_name_max += 6 + +# Update ComfyUI itself. +label = "{0:.<{max}}".format('Updating ComfyUI ', max=custom_extensions_name_max) +print(label, end = '') +show_output(pull_path('../')) + +# If we have custom nodes, update them as well. 
+if len(custom_extensions) > 0: + print(f'\nUpdating custom_nodes ({len(custom_extensions)}):') + for custom_extension in custom_extensions: + directory = custom_extension['directory'] + label = "{0:.<{max}}".format(f'🗀 {directory} ', max=custom_extensions_name_max) + print(label, end = '') + show_output(pull_path(directory)) diff --git a/rgthree-comfy/docs/rgthree_advanced.png b/rgthree-comfy/docs/rgthree_advanced.png new file mode 100644 index 0000000000000000000000000000000000000000..49d85226499c791db835ddc5b5fd91c6621b8c70 --- /dev/null +++ b/rgthree-comfy/docs/rgthree_advanced.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77d88a50847fa76d95a1470bf0315b035a06e3c87a8d4e8132d49d8d5b8a31ce +size 455664 diff --git a/rgthree-comfy/docs/rgthree_advanced_metadata.png b/rgthree-comfy/docs/rgthree_advanced_metadata.png new file mode 100644 index 0000000000000000000000000000000000000000..90aedcf8107dcadcd59519896af8bad9ae4399fa --- /dev/null +++ b/rgthree-comfy/docs/rgthree_advanced_metadata.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c33be251bd628225b9e29249b766b45539a62a9f94f2e0085b10c469f5ef0956 +size 491188 diff --git a/rgthree-comfy/docs/rgthree_context.png b/rgthree-comfy/docs/rgthree_context.png new file mode 100644 index 0000000000000000000000000000000000000000..583922074a628160d4d82f667fd0df54961f2a56 --- /dev/null +++ b/rgthree-comfy/docs/rgthree_context.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d7c4401b0f4b6958d75b9eca531b0c16d8eeb121f3ec8092ccf1336966c43cd +size 544154 diff --git a/rgthree-comfy/docs/rgthree_context_metadata.png b/rgthree-comfy/docs/rgthree_context_metadata.png new file mode 100644 index 0000000000000000000000000000000000000000..cc9ee9116fd3f12848db064e5c4967743752d6ab --- /dev/null +++ b/rgthree-comfy/docs/rgthree_context_metadata.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f89ef9decfee6b65831780a85235f6bfedd99ff1f16eaaf0567a17d94997ba7 +size 1689098 diff --git a/rgthree-comfy/docs/rgthree_router.png b/rgthree-comfy/docs/rgthree_router.png new file mode 100644 index 0000000000000000000000000000000000000000..f3581d708f98c6352dfcc414ac0bd81f2eb2fc59 Binary files /dev/null and b/rgthree-comfy/docs/rgthree_router.png differ diff --git a/rgthree-comfy/docs/rgthree_seed.png b/rgthree-comfy/docs/rgthree_seed.png new file mode 100644 index 0000000000000000000000000000000000000000..32c392fbccf38f7e0d150cbeed0cc9066f8ac610 Binary files /dev/null and b/rgthree-comfy/docs/rgthree_seed.png differ diff --git a/rgthree-comfy/package-lock.json b/rgthree-comfy/package-lock.json new file mode 100644 index 0000000000000000000000000000000000000000..d77bcc73035d7289f29e6607c189850bb7347746 --- /dev/null +++ b/rgthree-comfy/package-lock.json @@ -0,0 +1,261 @@ +{ + "name": "rgthree-comfy", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "rgthree-comfy", + "devDependencies": { + "prettier": "^3.3.3", + "sass": "^1.77.8", + "typescript": "^5.5.4" + } + }, + "node_modules/anymatch": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", + "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "dev": true, + "dependencies": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/binary-extensions": { + "version": "2.2.0", + "resolved": 
"https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", + "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/braces": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", + "dev": true, + "dependencies": { + "fill-range": "^7.1.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/chokidar": { + "version": "3.5.3", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", + "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://paulmillr.com/funding/" + } + ], + "dependencies": { + "anymatch": "~3.1.2", + "braces": "~3.0.2", + "glob-parent": "~5.1.2", + "is-binary-path": "~2.1.0", + "is-glob": "~4.0.1", + "normalize-path": "~3.0.0", + "readdirp": "~3.6.0" + }, + "engines": { + "node": ">= 8.10.0" + }, + "optionalDependencies": { + "fsevents": "~2.3.2" + } + }, + "node_modules/fill-range": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", + "dev": true, + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "dev": true, + "hasInstallScript": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/immutable": { + "version": "4.3.5", + "resolved": "https://registry.npmjs.org/immutable/-/immutable-4.3.5.tgz", + "integrity": "sha512-8eabxkth9gZatlwl5TBuJnCsoTADlL6ftEr7A4qgdaTsPyreilDSnUk57SO+jfKcNtxPa22U5KK6DSeAYhpBJw==", + "dev": true + }, + "node_modules/is-binary-path": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", + "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", + "dev": true, + "dependencies": { + "binary-extensions": "^2.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, + "dependencies": { + 
"is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "dev": true, + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/normalize-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", + "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "dev": true, + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/prettier": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz", + "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==", + "dev": true, + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, + "node_modules/readdirp": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", + "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", + "dev": true, + "dependencies": { + "picomatch": "^2.2.1" + }, + "engines": { + "node": ">=8.10.0" + } + }, + "node_modules/sass": { + "version": "1.77.8", + "resolved": "https://registry.npmjs.org/sass/-/sass-1.77.8.tgz", + "integrity": "sha512-4UHg6prsrycW20fqLGPShtEvo/WyHRVRHwOP4DzkUrObWoWI05QBSfzU71TVB7PFaL104TwNaHpjlWXAZbQiNQ==", + "dev": true, + "dependencies": { + "chokidar": ">=3.0.0 <4.0.0", + "immutable": "^4.0.0", + "source-map-js": ">=0.6.2 <2.0.0" + }, + "bin": { + "sass": "sass.js" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/source-map-js": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz", + "integrity": "sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dev": true, + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/typescript": { + "version": "5.5.4", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.5.4.tgz", + "integrity": "sha512-Mtq29sKDAEYP7aljRgtPOpTvOfbwRWlS6dPRzwjdE+C0R4brX/GUyhHSecbHMFLNBLcJIPt9nl9yG5TZ1weH+Q==", + "dev": true, + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + } + } +} diff --git a/rgthree-comfy/package.json b/rgthree-comfy/package.json new file mode 100644 index 0000000000000000000000000000000000000000..b3ae282304f4df37c93b42e6cfea0d86420c5f88 --- /dev/null +++ b/rgthree-comfy/package.json 
@@ -0,0 +1,10 @@ +{ + "devDependencies": { + "prettier": "^3.3.3", + "typescript": "^5.5.4", + "sass": "^1.77.8" + }, + "scripts": { + "build": "./__build__.py || python .\\__build__.py" + } +} diff --git a/rgthree-comfy/prestartup_script.py b/rgthree-comfy/prestartup_script.py new file mode 100644 index 0000000000000000000000000000000000000000..47a9c2f8d8c73ebaa77a3c9149111e1edfd62a44 --- /dev/null +++ b/rgthree-comfy/prestartup_script.py @@ -0,0 +1,4 @@ +import folder_paths + +# Add 'saved_prompts' as a folder for Power Prompt node. +folder_paths.folder_names_and_paths['saved_prompts'] = ([], set(['.txt'])) diff --git a/rgthree-comfy/py/__init__.py b/rgthree-comfy/py/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rgthree-comfy/py/any_switch.py b/rgthree-comfy/py/any_switch.py new file mode 100644 index 0000000000000000000000000000000000000000..63edba85cf463399c9337d6c33473bef216430bb --- /dev/null +++ b/rgthree-comfy/py/any_switch.py @@ -0,0 +1,38 @@ +from .context_utils import is_context_empty +from .constants import get_category, get_name +from .utils import FlexibleOptionalInputType, any_type + + +def is_none(value): + """Checks if a value is none. Pulled out in case we want to expand what 'None' means.""" + if value is not None: + if isinstance(value, dict) and 'model' in value and 'clip' in value: + return is_context_empty(value) + return value is None + + +class RgthreeAnySwitch: + """The dynamic Any Switch. """ + + NAME = get_name("Any Switch") + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": {}, + "optional": FlexibleOptionalInputType(any_type), + } + + RETURN_TYPES = (any_type,) + RETURN_NAMES = ('*',) + FUNCTION = "switch" + + def switch(self, **kwargs): + """Chooses the first non-empty item to output.""" + any_value = None + for key, value in kwargs.items(): + if key.startswith('any_') and not is_none(value): + any_value = value + break + return (any_value,) diff --git a/rgthree-comfy/py/config.py b/rgthree-comfy/py/config.py new file mode 100644 index 0000000000000000000000000000000000000000..04214d0babc38f3718f088a1fc265f1dc62b1da4 --- /dev/null +++ b/rgthree-comfy/py/config.py @@ -0,0 +1,91 @@ +import os +import json +import re + +from .utils import get_dict_value, set_dict_value, dict_has_key, load_json_file + + +def get_config_value(key): + return get_dict_value(RGTHREE_CONFIG, key) + + +def extend_config(default_config, user_config): + """ Returns a new config dict combining user_config into defined keys for default_config.""" + cfg = {} + for key, value in default_config.items(): + if key not in user_config: + cfg[key] = value + elif isinstance(value, dict): + cfg[key] = extend_config(value, user_config[key]) + else: + cfg[key] = user_config[key] if key in user_config else value + return cfg + + +def set_user_config(data: dict): + """ Sets the user configuration.""" + count = 0 + for key, value in data.items(): + if dict_has_key(DEFAULT_CONFIG, key): + set_dict_value(USER_CONFIG, key, value) + set_dict_value(RGTHREE_CONFIG, key, value) + count += 1 + if count > 0: + write_user_config() + + +def get_rgthree_default_config(): + """ Gets the default configuration.""" + return load_json_file(DEFAULT_CONFIG_FILE, default={}) + + +def get_rgthree_user_config(): + """ Gets the user configuration.""" + return load_json_file(USER_CONFIG_FILE, default={}) + + +def write_user_config(): + 
""" Writes the user configuration.""" + with open(USER_CONFIG_FILE, 'w+', encoding='UTF-8') as file: + json.dump(USER_CONFIG, file, sort_keys=True, indent=2, separators=(",", ": ")) + + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +DEFAULT_CONFIG_FILE = os.path.join(THIS_DIR, '..', 'rgthree_config.json.default') +USER_CONFIG_FILE = os.path.join(THIS_DIR, '..', 'rgthree_config.json') +DEFAULT_CONFIG = get_rgthree_default_config() + +USER_CONFIG = get_rgthree_user_config() + +# Migrate old config options into "features" +needs_to_write_user_config = False +if 'patch_recursive_execution' in USER_CONFIG: + if 'features' not in USER_CONFIG: + USER_CONFIG['features'] = {} + USER_CONFIG['features']['patch_recursive_execution'] = USER_CONFIG['patch_recursive_execution'] + del USER_CONFIG['patch_recursive_execution'] + needs_to_write_user_config = True + +if 'show_alerts_for_corrupt_workflows' in USER_CONFIG: + if 'features' not in USER_CONFIG: + USER_CONFIG['features'] = {} + USER_CONFIG['features']['show_alerts_for_corrupt_workflows'] = USER_CONFIG[ + 'show_alerts_for_corrupt_workflows'] + del USER_CONFIG['show_alerts_for_corrupt_workflows'] + needs_to_write_user_config = True + +if 'monitor_for_corrupt_links' in USER_CONFIG: + if 'features' not in USER_CONFIG: + USER_CONFIG['features'] = {} + USER_CONFIG['features']['monitor_for_corrupt_links'] = USER_CONFIG['monitor_for_corrupt_links'] + del USER_CONFIG['monitor_for_corrupt_links'] + needs_to_write_user_config = True + +if needs_to_write_user_config is True: + print('writing new user config.') + write_user_config() + +RGTHREE_CONFIG = extend_config(DEFAULT_CONFIG, USER_CONFIG) + +if "unreleased" in USER_CONFIG and "unreleased" not in RGTHREE_CONFIG: + RGTHREE_CONFIG["unreleased"] = USER_CONFIG["unreleased"] diff --git a/rgthree-comfy/py/constants.py b/rgthree-comfy/py/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..e5aba5ede7118787d1a34294b58e1680f6e5e718 --- /dev/null +++ b/rgthree-comfy/py/constants.py @@ -0,0 +1,11 @@ + +NAMESPACE='rgthree' + +def get_name(name): + return '{} ({})'.format(name, NAMESPACE) + +def get_category(sub_dirs = None): + if sub_dirs is None: + return NAMESPACE + else: + return "{}/utils".format(NAMESPACE) diff --git a/rgthree-comfy/py/context.py b/rgthree-comfy/py/context.py new file mode 100644 index 0000000000000000000000000000000000000000..25ae7f465d5c8d441e0a146d16fec1e52776470d --- /dev/null +++ b/rgthree-comfy/py/context.py @@ -0,0 +1,33 @@ +"""The Context node.""" +from .context_utils import (ORIG_CTX_OPTIONAL_INPUTS, ORIG_CTX_RETURN_NAMES, ORIG_CTX_RETURN_TYPES, + get_orig_context_return_tuple, new_context) +from .constants import get_category, get_name + + +class RgthreeContext: + """The initial Context node. + + For now, this nodes' outputs will remain as-is, as they are perfect for most 1.5 application, but + is also backwards compatible with other Context nodes. 
+ """ + + NAME = get_name("Context") + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": {}, + "optional": ORIG_CTX_OPTIONAL_INPUTS, + "hidden": { + "version": "FLOAT" + }, + } + + RETURN_TYPES = ORIG_CTX_RETURN_TYPES + RETURN_NAMES = ORIG_CTX_RETURN_NAMES + FUNCTION = "convert" + + def convert(self, base_ctx=None, **kwargs): # pylint: disable = missing-function-docstring + ctx = new_context(base_ctx, **kwargs) + return get_orig_context_return_tuple(ctx) diff --git a/rgthree-comfy/py/context_big.py b/rgthree-comfy/py/context_big.py new file mode 100644 index 0000000000000000000000000000000000000000..301b11d7fdca9bcfab21d44599129e6a98a8ea9c --- /dev/null +++ b/rgthree-comfy/py/context_big.py @@ -0,0 +1,31 @@ +"""The Conmtext big node.""" +from .constants import get_category, get_name +from .context_utils import (ALL_CTX_OPTIONAL_INPUTS, ALL_CTX_RETURN_NAMES, ALL_CTX_RETURN_TYPES, + new_context, get_context_return_tuple) + + +class RgthreeBigContext: + """The Context Big node. + + This context node will expose all context fields as inputs and outputs. It is backwards compatible + with other context nodes and can be intertwined with them. + """ + + NAME = get_name("Context Big") + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name,missing-function-docstring + return { + "required": {}, + "optional": ALL_CTX_OPTIONAL_INPUTS, + "hidden": {}, + } + + RETURN_TYPES = ALL_CTX_RETURN_TYPES + RETURN_NAMES = ALL_CTX_RETURN_NAMES + FUNCTION = "convert" + + def convert(self, base_ctx=None, **kwargs): # pylint: disable = missing-function-docstring + ctx = new_context(base_ctx, **kwargs) + return get_context_return_tuple(ctx) diff --git a/rgthree-comfy/py/context_merge.py b/rgthree-comfy/py/context_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..be1642f9f383b1832db91c0ce2d5d51a8120e28f --- /dev/null +++ b/rgthree-comfy/py/context_merge.py @@ -0,0 +1,37 @@ +"""The Context Switch (Big).""" +from .constants import get_category, get_name +from .context_utils import (ORIG_CTX_RETURN_TYPES, ORIG_CTX_RETURN_NAMES, merge_new_context, + get_orig_context_return_tuple, is_context_empty) +from .utils import FlexibleOptionalInputType + + +class RgthreeContextMerge: + """The Context Merge node.""" + + NAME = get_name("Context Merge") + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": {}, + "optional": FlexibleOptionalInputType("RGTHREE_CONTEXT"), + } + + RETURN_TYPES = ORIG_CTX_RETURN_TYPES + RETURN_NAMES = ORIG_CTX_RETURN_NAMES + FUNCTION = "merge" + + def get_return_tuple(self, ctx): + """Returns the context data. 
Separated so it can be overridden.""" + return get_orig_context_return_tuple(ctx) + + def merge(self, **kwargs): + """Merges any non-null passed contexts; later ones override earlier ones.""" + ctxs = [ + value for key, value in kwargs.items() + if key.startswith('ctx_') and not is_context_empty(value) + ] + ctx = merge_new_context(*ctxs) + + return self.get_return_tuple(ctx) diff --git a/rgthree-comfy/py/context_merge_big.py b/rgthree-comfy/py/context_merge_big.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6c8ff4dfbecd31de89a4b8c9726c757e58d74b --- /dev/null +++ b/rgthree-comfy/py/context_merge_big.py @@ -0,0 +1,16 @@ +"""The Context Merge (Big).""" +from .constants import get_category, get_name +from .context_utils import (ALL_CTX_RETURN_TYPES, ALL_CTX_RETURN_NAMES, get_context_return_tuple) +from .context_merge import RgthreeContextMerge + + +class RgthreeContextMergeBig(RgthreeContextMerge): + """The Context Merge Big node.""" + + NAME = get_name("Context Merge Big") + RETURN_TYPES = ALL_CTX_RETURN_TYPES + RETURN_NAMES = ALL_CTX_RETURN_NAMES + + def get_return_tuple(self, ctx): + """Returns the context data. Separated so it can be overridden.""" + return get_context_return_tuple(ctx) diff --git a/rgthree-comfy/py/context_switch.py b/rgthree-comfy/py/context_switch.py new file mode 100644 index 0000000000000000000000000000000000000000..c112f16ead8d94e42a2c033688c3ce0eff37be77 --- /dev/null +++ b/rgthree-comfy/py/context_switch.py @@ -0,0 +1,36 @@ +"""The original Context Switch.""" +from .constants import get_category, get_name +from .context_utils import (ORIG_CTX_RETURN_TYPES, ORIG_CTX_RETURN_NAMES, is_context_empty, + get_orig_context_return_tuple) +from .utils import FlexibleOptionalInputType + + +class RgthreeContextSwitch: + """The (original) Context Switch node.""" + + NAME = get_name("Context Switch") + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": {}, + "optional": FlexibleOptionalInputType("RGTHREE_CONTEXT"), + } + + RETURN_TYPES = ORIG_CTX_RETURN_TYPES + RETURN_NAMES = ORIG_CTX_RETURN_NAMES + FUNCTION = "switch" + + def get_return_tuple(self, ctx): + """Returns the context data.
Separated so it can be overridden.""" + return get_orig_context_return_tuple(ctx) + + def switch(self, **kwargs): + """Chooses the first non-empty Context to output.""" + ctx = None + for key, value in kwargs.items(): + if key.startswith('ctx_') and not is_context_empty(value): + ctx = value + break + return self.get_return_tuple(ctx) diff --git a/rgthree-comfy/py/context_switch_big.py b/rgthree-comfy/py/context_switch_big.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8248cd8aa91d8c75d186e15f3edbb8807ae281 --- /dev/null +++ b/rgthree-comfy/py/context_switch_big.py @@ -0,0 +1,16 @@ +"""The Context Switch (Big).""" +from .constants import get_category, get_name +from .context_utils import (ALL_CTX_RETURN_TYPES, ALL_CTX_RETURN_NAMES, get_context_return_tuple) +from .context_switch import RgthreeContextSwitch + + +class RgthreeContextSwitchBig(RgthreeContextSwitch): + """The Context Switch Big node.""" + + NAME = get_name("Context Switch Big") + RETURN_TYPES = ALL_CTX_RETURN_TYPES + RETURN_NAMES = ALL_CTX_RETURN_NAMES + + def get_return_tuple(self, ctx): + """Overrides the RgthreeContextSwitch `get_return_tuple` to return big context data.""" + return get_context_return_tuple(ctx) diff --git a/rgthree-comfy/py/context_utils.py b/rgthree-comfy/py/context_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a067c8341fefc9c7a78ca6fc510ad07c1a18016 --- /dev/null +++ b/rgthree-comfy/py/context_utils.py @@ -0,0 +1,118 @@ +"""A set of constants and utilities for handling contexts. + +Sets up the inputs and outputs for the Context going forward, with additional functions for +creating and exporting context objects. +""" +import comfy.samplers +import folder_paths + +_all_context_input_output_data = { + "base_ctx": ("base_ctx", "RGTHREE_CONTEXT", "CONTEXT"), + "model": ("model", "MODEL", "MODEL"), + "clip": ("clip", "CLIP", "CLIP"), + "vae": ("vae", "VAE", "VAE"), + "positive": ("positive", "CONDITIONING", "POSITIVE"), + "negative": ("negative", "CONDITIONING", "NEGATIVE"), + "latent": ("latent", "LATENT", "LATENT"), + "images": ("images", "IMAGE", "IMAGE"), + "seed": ("seed", "INT", "SEED"), + "steps": ("steps", "INT", "STEPS"), + "step_refiner": ("step_refiner", "INT", "STEP_REFINER"), + "cfg": ("cfg", "FLOAT", "CFG"), + "ckpt_name": ("ckpt_name", folder_paths.get_filename_list("checkpoints"), "CKPT_NAME"), + "sampler": ("sampler", comfy.samplers.KSampler.SAMPLERS, "SAMPLER"), + "scheduler": ("scheduler", comfy.samplers.KSampler.SCHEDULERS, "SCHEDULER"), + "clip_width": ("clip_width", "INT", "CLIP_WIDTH"), + "clip_height": ("clip_height", "INT", "CLIP_HEIGHT"), + "text_pos_g": ("text_pos_g", "STRING", "TEXT_POS_G"), + "text_pos_l": ("text_pos_l", "STRING", "TEXT_POS_L"), + "text_neg_g": ("text_neg_g", "STRING", "TEXT_NEG_G"), + "text_neg_l": ("text_neg_l", "STRING", "TEXT_NEG_L"), + "mask": ("mask", "MASK", "MASK"), + "control_net": ("control_net", "CONTROL_NET", "CONTROL_NET"), +} + +force_input_types = ["INT", "STRING", "FLOAT"] +force_input_names = ["sampler", "scheduler", "ckpt_name"] + + +def _create_context_data(input_list=None): + """Returns a tuple of context inputs, return types, and return names to use in a node"s def""" + if input_list is None: + input_list = _all_context_input_output_data.keys() + list_ctx_return_types = [] + list_ctx_return_names = [] + ctx_optional_inputs = {} + for inp in input_list: + data = _all_context_input_output_data[inp] + list_ctx_return_types.append(data[1]) + list_ctx_return_names.append(data[2]) + 
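+ # INT, STRING, and FLOAT types, plus the sampler, scheduler, and ckpt_name fields, get "forceInput" so ComfyUI exposes them as connectable inputs rather than widgets.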
ctx_optional_inputs[data[0]] = tuple([data[1]] + ([{ + "forceInput": True + }] if data[1] in force_input_types or data[0] in force_input_names else [])) + + ctx_return_types = tuple(list_ctx_return_types) + ctx_return_names = tuple(list_ctx_return_names) + return (ctx_optional_inputs, ctx_return_types, ctx_return_names) + + +ALL_CTX_OPTIONAL_INPUTS, ALL_CTX_RETURN_TYPES, ALL_CTX_RETURN_NAMES = _create_context_data() + +_original_ctx_inputs_list = [ + "base_ctx", "model", "clip", "vae", "positive", "negative", "latent", "images", "seed" +] +ORIG_CTX_OPTIONAL_INPUTS, ORIG_CTX_RETURN_TYPES, ORIG_CTX_RETURN_NAMES = _create_context_data( + _original_ctx_inputs_list) + + +def new_context(base_ctx, **kwargs): + """Creates a new context from the provided data, with an optional base ctx to start.""" + context = base_ctx if base_ctx is not None else None + new_ctx = {} + for key in _all_context_input_output_data: + if key == "base_ctx": + continue + v = kwargs[key] if key in kwargs else None + new_ctx[key] = v if v is not None else context[ + key] if context is not None and key in context else None + return new_ctx + + +def merge_new_context(*args): + """Creates a new context by merging provided contexts with the latter overriding same fields.""" + new_ctx = {} + for key in _all_context_input_output_data: + if key == "base_ctx": + continue + v = None + # Move backwards through the passed contexts until we find a value and use it. + for ctx in reversed(args): + v = ctx[key] if not is_context_empty(ctx) and key in ctx else None + if v is not None: + break + new_ctx[key] = v + return new_ctx + + +def get_context_return_tuple(ctx, inputs_list=None): + """Returns a tuple for returning in the order of the inputs list.""" + if inputs_list is None: + inputs_list = _all_context_input_output_data.keys() + tup_list = [ + ctx, + ] + for key in inputs_list: + if key == "base_ctx": + continue + tup_list.append(ctx[key] if ctx is not None and key in ctx else None) + return tuple(tup_list) + + +def get_orig_context_return_tuple(ctx): + """Returns a tuple for returning from a node with only the original context keys.""" + return get_context_return_tuple(ctx, _original_ctx_inputs_list) + + +def is_context_empty(ctx): + """Checks if the provided ctx is None or contains just None values.""" + return not ctx or all(v is None for v in ctx.values()) diff --git a/rgthree-comfy/py/display_any.py b/rgthree-comfy/py/display_any.py new file mode 100644 index 0000000000000000000000000000000000000000..522721dc2c714dce5ac9e386549c7147dc306582 --- /dev/null +++ b/rgthree-comfy/py/display_any.py @@ -0,0 +1,71 @@ +import json +from .constants import get_category, get_name + + +class AnyType(str): + """A special class that is always equal in not equal comparisons. Credit to pythongosssss""" + + def __ne__(self, __value: object) -> bool: + return False + + +any = AnyType("*") + + +class RgthreeDisplayAny: + """Display any data node.""" + + NAME = get_name('Display Any') + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": { + "source": (any, {}), + }, + } + + RETURN_TYPES = () + FUNCTION = "main" + OUTPUT_NODE = True + + def main(self, source=None): + value = 'None' + if source is not None: + try: + value = json.dumps(source) + except Exception: + try: + value = str(source) + except Exception: + value = 'source exists, but could not be serialized.' 
+ + return {"ui": {"text": (value,)}} + + +class RgthreeDisplayInt: + """Old DisplayInt node. + + Can be ported over to DisplayAny if https://github.com/comfyanonymous/ComfyUI/issues/1527 fixed. + """ + + NAME = get_name('Display Int') + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "input": ("INT", { + "forceInput": True + }), + }, + } + + RETURN_TYPES = () + FUNCTION = "main" + OUTPUT_NODE = True + + def main(self, input=None): + return {"ui": {"text": (input,)}} diff --git a/rgthree-comfy/py/dynamic_context.py b/rgthree-comfy/py/dynamic_context.py new file mode 100644 index 0000000000000000000000000000000000000000..411210d0e0cbc3c038b8b9c0822b69658fc63c7d --- /dev/null +++ b/rgthree-comfy/py/dynamic_context.py @@ -0,0 +1,56 @@ +"""The Dynamic Context node.""" +from mimetypes import add_type +from .constants import get_category, get_name +from .utils import ByPassTypeTuple, FlexibleOptionalInputType + + +class RgthreeDynamicContext: + """The Dynamic Context node. + + Similar to the static Context and Context Big nodes, this allows users to add any number and + variety of inputs to a Dynamic Context node, and return the outputs by key name. + """ + + NAME = get_name("Dynamic Context") + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name,missing-function-docstring + return { + "required": {}, + "optional": FlexibleOptionalInputType(add_type), + "hidden": {}, + } + + RETURN_TYPES = ByPassTypeTuple(("RGTHREE_DYNAMIC_CONTEXT",)) + RETURN_NAMES = ByPassTypeTuple(("CONTEXT",)) + FUNCTION = "main" + + def main(self, **kwargs): + """Creates a new context from the provided data, with an optional base ctx to start. + + This node takes a list of named inputs that are the named keys (with an optional "+ " prefix) + which are to be stored within the ctx dict as well as a list of keys contained in `output_keys` + to determine the list of output data. 
+ """ + + base_ctx = kwargs.get('base_ctx', None) + output_keys = kwargs.get('output_keys', None) + + new_ctx = base_ctx.copy() if base_ctx is not None else {} + + for key_raw, value in kwargs.items(): + if key_raw in ['base_ctx', 'output_keys']: + continue + key = key_raw.upper() + if key.startswith('+ '): + key = key[2:] + new_ctx[key] = value + + print(new_ctx) + + res = [new_ctx] + output_keys = output_keys.split(',') if output_keys is not None else [] + for key in output_keys: + res.append(new_ctx[key] if key in new_ctx else None) + return tuple(res) diff --git a/rgthree-comfy/py/dynamic_context_switch.py b/rgthree-comfy/py/dynamic_context_switch.py new file mode 100644 index 0000000000000000000000000000000000000000..3d1b9e6a7eb1a592db3941298e57ff0c38dfb5e0 --- /dev/null +++ b/rgthree-comfy/py/dynamic_context_switch.py @@ -0,0 +1,39 @@ +"""The original Context Switch.""" +from .constants import get_category, get_name +from .context_utils import is_context_empty +from .utils import ByPassTypeTuple, FlexibleOptionalInputType + + +class RgthreeDynamicContextSwitch: + """The initial Context Switch node.""" + + NAME = get_name("Dynamic Context Switch") + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": {}, + "optional": FlexibleOptionalInputType("RGTHREE_DYNAMIC_CONTEXT"), + } + + RETURN_TYPES = ByPassTypeTuple(("RGTHREE_DYNAMIC_CONTEXT",)) + RETURN_NAMES = ByPassTypeTuple(("CONTEXT",)) + FUNCTION = "switch" + + def switch(self, **kwargs): + """Chooses the first non-empty Context to output.""" + + output_keys = kwargs.get('output_keys', None) + + ctx = None + for key, value in kwargs.items(): + if key.startswith('ctx_') and not is_context_empty(value): + ctx = value + break + + res = [ctx] + output_keys = output_keys.split(',') if output_keys is not None else [] + for key in output_keys: + res.append(ctx[key] if ctx is not None and key in ctx else None) + return tuple(res) diff --git a/rgthree-comfy/py/image_comparer.py b/rgthree-comfy/py/image_comparer.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4b1a634fd1cf06b39736507ca327ee4763d91e --- /dev/null +++ b/rgthree-comfy/py/image_comparer.py @@ -0,0 +1,41 @@ +from nodes import PreviewImage + +from .constants import get_category, get_name + + +class RgthreeImageComparer(PreviewImage): + """A node that compares two images in the UI.""" + + NAME = get_name('Image Comparer') + CATEGORY = get_category() + FUNCTION = "compare_images" + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": {}, + "optional": { + "image_a": ("IMAGE",), + "image_b": ("IMAGE",), + }, + "hidden": { + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + def compare_images(self, + image_a=None, + image_b=None, + filename_prefix="rgthree.compare.", + prompt=None, + extra_pnginfo=None): + + result = { "ui": { "a_images":[], "b_images": [] } } + if image_a is not None and len(image_a) > 0: + result['ui']['a_images'] = self.save_images(image_a, filename_prefix, prompt, extra_pnginfo)['ui']['images'] + + if image_b is not None and len(image_b) > 0: + result['ui']['b_images'] = self.save_images(image_b, filename_prefix, prompt, extra_pnginfo)['ui']['images'] + + return result \ No newline at end of file diff --git a/rgthree-comfy/py/image_inset_crop.py b/rgthree-comfy/py/image_inset_crop.py new file mode 100644 index 
0000000000000000000000000000000000000000..6e672eaa33c43c31520d72d833a90886d44ac888 --- /dev/null +++ b/rgthree-comfy/py/image_inset_crop.py @@ -0,0 +1,93 @@ +"""Image Inset Crop, with percentages.""" +from .log import log_node_info +from .constants import get_category, get_name +from nodes import MAX_RESOLUTION + + +def get_new_bounds(width, height, left, right, top, bottom): + """Returns the new bounds for an image with inset crop data.""" + left = 0 + left + right = width - right + top = 0 + top + bottom = height - bottom + return (left, right, top, bottom) + + +class RgthreeImageInsetCrop: + """Image Inset Crop, with percentages.""" + + NAME = get_name('Image Inset Crop') + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": { + "image": ("IMAGE",), + "measurement": (['Pixels', 'Percentage'],), + "left": ("INT", { + "default": 0, + "min": 0, + "max": MAX_RESOLUTION, + "step": 8 + }), + "right": ("INT", { + "default": 0, + "min": 0, + "max": MAX_RESOLUTION, + "step": 8 + }), + "top": ("INT", { + "default": 0, + "min": 0, + "max": MAX_RESOLUTION, + "step": 8 + }), + "bottom": ("INT", { + "default": 0, + "min": 0, + "max": MAX_RESOLUTION, + "step": 8 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "crop" + + # pylint: disable = too-many-arguments + def crop(self, measurement, left, right, top, bottom, image=None): + """Does the crop.""" + + _, height, width, _ = image.shape + + if measurement == 'Percentage': + left = int(width - (width * (100 - left) / 100)) + right = int(width - (width * (100 - right) / 100)) + top = int(height - (height * (100 - top) / 100)) + bottom = int(height - (height * (100 - bottom) / 100)) + + # Snap to 8 pixels + left = left // 8 * 8 + right = right // 8 * 8 + top = top // 8 * 8 + bottom = bottom // 8 * 8 + + if left == 0 and right == 0 and bottom == 0 and top == 0: + return (image,) + + inset_left, inset_right, inset_top, inset_bottom = get_new_bounds(width, height, left, right, + top, bottom) + if inset_top > inset_bottom: + raise ValueError( + f"Invalid cropping dimensions top ({inset_top}) exceeds bottom ({inset_bottom})") + if inset_left > inset_right: + raise ValueError( + f"Invalid cropping dimensions left ({inset_left}) exceeds right ({inset_right})") + + log_node_info( + self.NAME, f'Cropping image {width}x{height} width inset by {inset_left},{inset_right}, ' + + f'and height inset by {inset_top}, {inset_bottom}') + image = image[:, inset_top:inset_bottom, inset_left:inset_right, :] + + return (image,) diff --git a/rgthree-comfy/py/ksampler_config.py b/rgthree-comfy/py/ksampler_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b971007d3d4a1f01bd766784082c441f024acaaa --- /dev/null +++ b/rgthree-comfy/py/ksampler_config.py @@ -0,0 +1,56 @@ +"""Some basic config stuff I use for SDXL.""" + +from .constants import get_category, get_name +from nodes import MAX_RESOLUTION +import comfy.samplers + + +class RgthreeKSamplerConfig: + """Some basic config stuff I started using for SDXL, but useful in other spots too.""" + + NAME = get_name('KSampler Config') + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": { + "steps_total": ("INT", { + "default": 30, + "min": 1, + "max": MAX_RESOLUTION, + "step": 1, + }), + "refiner_step": ("INT", { + "default": 24, + "min": 1, + "max": MAX_RESOLUTION, + "step": 1, + }), + "cfg": 
("FLOAT", { + "default": 8.0, + "min": 0.0, + "max": 100.0, + "step": 0.5, + }), + "sampler_name": (comfy.samplers.KSampler.SAMPLERS,), + "scheduler": (comfy.samplers.KSampler.SCHEDULERS,), + #"refiner_ascore_pos": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), + #"refiner_ascore_neg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), + }, + } + + RETURN_TYPES = ("INT", "INT", "FLOAT", comfy.samplers.KSampler.SAMPLERS, + comfy.samplers.KSampler.SCHEDULERS) + RETURN_NAMES = ("STEPS", "REFINER_STEP", "CFG", "SAMPLER", "SCHEDULER") + FUNCTION = "main" + + def main(self, steps_total, refiner_step, cfg, sampler_name, scheduler): + """main""" + return ( + steps_total, + refiner_step, + cfg, + sampler_name, + scheduler, + ) diff --git a/rgthree-comfy/py/log.py b/rgthree-comfy/py/log.py new file mode 100644 index 0000000000000000000000000000000000000000..6d26aead2222d8920fb489f1ede97e97b2d13870 --- /dev/null +++ b/rgthree-comfy/py/log.py @@ -0,0 +1,81 @@ +# https://stackoverflow.com/questions/4842424/list-of-ansi-color-escape-sequences +# https://en.wikipedia.org/wiki/ANSI_escape_code#3-bit_and_4-bit +COLORS = { + 'BLACK': '\33[30m', + 'RED': '\33[31m', + 'GREEN': '\33[32m', + 'YELLOW': '\33[33m', + 'BLUE': '\33[34m', + 'MAGENTA': '\33[35m', + 'CYAN': '\33[36m', + 'WHITE': '\33[37m', + 'GREY': '\33[90m', + 'BRIGHT_RED': '\33[91m', + 'BRIGHT_GREEN': '\33[92m', + 'BRIGHT_YELLOW': '\33[93m', + 'BRIGHT_BLUE': '\33[94m', + 'BRIGHT_MAGENTA': '\33[95m', + 'BRIGHT_CYAN': '\33[96m', + 'BRIGHT_WHITE': '\33[97m', + # Styles. + 'RESET': '\33[00m', + 'BOLD': '\33[01m', + 'NORMAL': '\33[22m', + 'ITALIC': '\33[03m', + 'UNDERLINE': '\33[04m', + 'BLINK': '\33[05m', + 'BLINK2': '\33[06m', + 'SELECTED': '\33[07m', + # Backgrounds + 'BG_BLACK': '\33[40m', + 'BG_RED': '\33[41m', + 'BG_GREEN': '\33[42m', + 'BG_YELLOW': '\33[43m', + 'BG_BLUE': '\33[44m', + 'BG_MAGENTA': '\33[45m', + 'BG_CYAN': '\33[46m', + 'BG_WHITE': '\33[47m', + 'BG_GREY': '\33[100m', + 'BG_BRIGHT_RED': '\33[101m', + 'BG_BRIGHT_GREEN': '\33[102m', + 'BG_BRIGHT_YELLOW': '\33[103m', + 'BG_BRIGHT_BLUE': '\33[104m', + 'BG_BRIGHT_MAGENTA': '\33[105m', + 'BG_BRIGHT_CYAN': '\33[106m', + 'BG_BRIGHT_WHITE': '\33[107m', +} + + +def log_node_success(node_name, message, msg_color='RESET'): + """Logs a success message.""" + _log_node("BRIGHT_GREEN", node_name, message, msg_color=msg_color) + + +def log_node_info(node_name, message, msg_color='RESET'): + """Logs an info message.""" + _log_node("CYAN", node_name, message, msg_color=msg_color) + + +def log_node_warn(node_name, message, msg_color='RESET'): + """Logs an warn message.""" + _log_node("YELLOW", node_name, message, msg_color=msg_color) + + +def log_node(node_name, message, msg_color='RESET'): + """Logs a message.""" + _log_node("CYAN", node_name, message, msg_color=msg_color) + + +def _log_node(color, node_name, message, msg_color='RESET'): + """Logs for a node message.""" + log(message, color=color, prefix=node_name.replace(" (rgthree)", ""), msg_color=msg_color) + + +def log(message, color=None, msg_color=None, prefix=None): + """Basic logging.""" + color = COLORS[color] if color is not None and color in COLORS else COLORS["BRIGHT_GREEN"] + msg_color = COLORS[msg_color] if msg_color is not None and msg_color in COLORS else '' + prefix = f'[{prefix}]' if prefix is not None else '' + msg = f'{color}[rgthree]{prefix}' + msg += f'{msg_color} {message}{COLORS["RESET"]}' + print(msg) diff --git a/rgthree-comfy/py/lora_stack.py b/rgthree-comfy/py/lora_stack.py new file 
mode 100644 index 0000000000000000000000000000000000000000..420ae1390ad6de114424efaba45b7c04853af312 --- /dev/null +++ b/rgthree-comfy/py/lora_stack.py @@ -0,0 +1,46 @@ +from .constants import get_category, get_name +from nodes import LoraLoader +import folder_paths + + +class RgthreeLoraLoaderStack: + + NAME = get_name('Lora Loader Stack') + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": { + "model": ("MODEL",), + "clip": ("CLIP", ), + + "lora_01": (['None'] + folder_paths.get_filename_list("loras"), ), + "strength_01":("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + + "lora_02": (['None'] + folder_paths.get_filename_list("loras"), ), + "strength_02":("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + + "lora_03": (['None'] + folder_paths.get_filename_list("loras"), ), + "strength_03":("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + + "lora_04": (['None'] + folder_paths.get_filename_list("loras"), ), + "strength_04":("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("MODEL", "CLIP") + FUNCTION = "load_lora" + + def load_lora(self, model, clip, lora_01, strength_01, lora_02, strength_02, lora_03, strength_03, lora_04, strength_04): + if lora_01 != "None" and strength_01 != 0: + model, clip = LoraLoader().load_lora(model, clip, lora_01, strength_01, strength_01) + if lora_02 != "None" and strength_02 != 0: + model, clip = LoraLoader().load_lora(model, clip, lora_02, strength_02, strength_02) + if lora_03 != "None" and strength_03 != 0: + model, clip = LoraLoader().load_lora(model, clip, lora_03, strength_03, strength_03) + if lora_04 != "None" and strength_04 != 0: + model, clip = LoraLoader().load_lora(model, clip, lora_04, strength_04, strength_04) + + return (model, clip) + diff --git a/rgthree-comfy/py/power_lora_loader.py b/rgthree-comfy/py/power_lora_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..12969d81c8cfc2464ad92f7a2b0b68e1e90d5012 --- /dev/null +++ b/rgthree-comfy/py/power_lora_loader.py @@ -0,0 +1,44 @@ +from nodes import LoraLoader +from .constants import get_category, get_name +from .power_prompt_utils import get_lora_by_filename +from .utils import FlexibleOptionalInputType, any_type + + +class RgthreePowerLoraLoader: + """ The Power Lora Loader is a powerful, flexible node to add multiple loras to a model/clip.""" + + NAME = get_name('Power Lora Loader') + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": { + "model": ("MODEL",), + "clip": ("CLIP",), + }, + # Since we will pass any number of loras in from the UI, this needs to always allow an + "optional": FlexibleOptionalInputType(any_type), + "hidden": {}, + } + + RETURN_TYPES = ("MODEL", "CLIP") + RETURN_NAMES = ("MODEL", "CLIP") + FUNCTION = "load_loras" + + def load_loras(self, model, clip, **kwargs): + """Loops over the provided loras in kwargs and applies valid ones.""" + for key, value in kwargs.items(): + key = key.upper() + if key.startswith('LORA_') and 'on' in value and 'lora' in value and 'strength' in value: + strength_model = value['strength'] + # If we just passed one strtength value, then use it for both, if we passed a strengthTwo + # as well, then our `strength` will be for the model, and `strengthTwo` for clip. 
+ strength_clip = value['strengthTwo'] if 'strengthTwo' in value and value[ + 'strengthTwo'] is not None else strength_model + if value['on'] and (strength_model != 0 or strength_clip != 0): + lora = get_lora_by_filename(value['lora'], log_node=self.NAME) + if lora is not None: + model, clip = LoraLoader().load_lora(model, clip, lora, strength_model, strength_clip) + + return (model, clip) diff --git a/rgthree-comfy/py/power_prompt.py b/rgthree-comfy/py/power_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..17eb4f9766a01f56929622666bb98f3dd533f537 --- /dev/null +++ b/rgthree-comfy/py/power_prompt.py @@ -0,0 +1,95 @@ +import os + +from .log import log_node_warn, log_node_info, log_node_success + +from .constants import get_category, get_name +from .power_prompt_utils import get_and_strip_loras +from nodes import LoraLoader, CLIPTextEncode +import folder_paths + +NODE_NAME = get_name('Power Prompt') + + +class RgthreePowerPrompt: + + NAME = NODE_NAME + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + SAVED_PROMPTS_FILES = folder_paths.get_filename_list('saved_prompts') + SAVED_PROMPTS_CONTENT = [] + for filename in SAVED_PROMPTS_FILES: + with open(folder_paths.get_full_path('saved_prompts', filename), 'r') as f: + SAVED_PROMPTS_CONTENT.append(f.read()) + return { + 'required': { + 'prompt': ('STRING', { + 'multiline': True + }), + }, + 'optional': { + "opt_model": ("MODEL",), + "opt_clip": ("CLIP",), + 'insert_lora': (['CHOOSE', 'DISABLE LORAS'] + + [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('loras')],), + 'insert_embedding': ([ + 'CHOOSE', + ] + [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('embeddings')],), + 'insert_saved': ([ + 'CHOOSE', + ] + SAVED_PROMPTS_FILES,), + }, + 'hidden': { + 'values_insert_saved': (['CHOOSE'] + SAVED_PROMPTS_CONTENT,), + } + } + + RETURN_TYPES = ( + 'CONDITIONING', + 'MODEL', + 'CLIP', + 'STRING', + ) + RETURN_NAMES = ( + 'CONDITIONING', + 'MODEL', + 'CLIP', + 'TEXT', + ) + FUNCTION = 'main' + + def main(self, + prompt, + opt_model=None, + opt_clip=None, + insert_lora=None, + insert_embedding=None, + insert_saved=None, + values_insert_saved=None): + if insert_lora == 'DISABLE LORAS': + prompt, loras, skipped, unfound = get_and_strip_loras(prompt, log_node=NODE_NAME, silent=True) + log_node_info( + NODE_NAME, + f'Disabling all found loras ({len(loras)}) and stripping lora tags for TEXT output.') + elif opt_model is not None and opt_clip is not None: + prompt, loras, skipped, unfound = get_and_strip_loras(prompt, log_node=NODE_NAME) + if len(loras) > 0: + for lora in loras: + opt_model, opt_clip = LoraLoader().load_lora(opt_model, opt_clip, lora['lora'], + lora['strength'], lora['strength']) + log_node_success(NODE_NAME, f'Loaded "{lora["lora"]}" from prompt') + log_node_info(NODE_NAME, f'{len(loras)} Loras processed; stripping tags for TEXT output.') + elif ' 1 and len(match[1]) else 1.0) + if strength == 0: + if not silent: + log_node_info(log_node, f'Skipping "{tag_path}" with strength of zero') + skipped_loras.append({'lora': tag_path, 'strength': strength}) + continue + + lora_path = get_lora_by_filename(tag_path, lora_paths, log_node=None if silent else log_node) + if lora_path is None: + unfound_loras.append({'lora': tag_path, 'strength': strength}) + continue + + loras.append({'lora': lora_path, 'strength': strength}) + + return (re.sub(pattern, '', prompt), loras, skipped_loras, unfound_loras) + + +# 
pylint: disable = too-many-return-statements, too-many-branches +def get_lora_by_filename(file_path, lora_paths=None, log_node=None): + """Returns a lora by filename, looking for exact paths and then fuzzier matching.""" + lora_paths = lora_paths if lora_paths is not None else folder_paths.get_filename_list('loras') + + if file_path in lora_paths: + return file_path + + lora_paths_no_ext = [os.path.splitext(x)[0] for x in lora_paths] + + # See if we've entered the exact path, but without the extension + if file_path in lora_paths_no_ext: + found = lora_paths[lora_paths_no_ext.index(file_path)] + return found + + # Same check, but ensure file_path is without extension. + file_path_force_no_ext = os.path.splitext(file_path)[0] + if file_path_force_no_ext in lora_paths_no_ext: + found = lora_paths[lora_paths_no_ext.index(file_path_force_no_ext)] + return found + + # See if we passed just the name, without paths. + lora_filenames_only = [os.path.basename(x) for x in lora_paths] + if file_path in lora_filenames_only: + found = lora_paths[lora_filenames_only.index(file_path)] + if log_node is not None: + log_node_info(log_node, f'Matched Lora input "{file_path}" to "{found}".') + return found + + # Same check, but force the input to be without paths. + file_path_force_filename = os.path.basename(file_path) + if file_path_force_filename in lora_filenames_only: + found = lora_paths[lora_filenames_only.index(file_path_force_filename)] + if log_node is not None: + log_node_info(log_node, f'Matched Lora input "{file_path}" to "{found}".') + return found + + # Check the filenames without extension. + lora_filenames_and_no_ext = [os.path.splitext(os.path.basename(x))[0] for x in lora_paths] + if file_path in lora_filenames_and_no_ext: + found = lora_paths[lora_filenames_and_no_ext.index(file_path)] + if log_node is not None: + log_node_info(log_node, f'Matched Lora input "{file_path}" to "{found}".') + return found + + # One last check, forcing the input to be without paths and extension. + file_path_force_filename_and_no_ext = os.path.splitext(os.path.basename(file_path))[0] + if file_path_force_filename_and_no_ext in lora_filenames_and_no_ext: + found = lora_paths[lora_filenames_and_no_ext.index(file_path_force_filename_and_no_ext)] + if log_node is not None: + log_node_info(log_node, f'Matched Lora input "{file_path}" to "{found}".') + return found + + # Finally, super fuzzy, we'll just check if the input exists in the path at all.
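+ # NOTE: this substring check returns the first path that contains the input, so an ambiguous name may resolve to an unintended lora.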
+ for index, lora_path in enumerate(lora_paths): + if file_path in lora_path: + found = lora_paths[index] + if log_node is not None: + log_node_warn(log_node, f'Fuzzy-matched Lora input "{file_path}" to "{found}".') + return found + + if log_node is not None: + log_node_warn(log_node, f'Lora "{file_path}" not found, skipping.') + + return None diff --git a/rgthree-comfy/py/rgthree_server.py b/rgthree-comfy/py/rgthree_server.py new file mode 100644 index 0000000000000000000000000000000000000000..af6e68ea9d47dd617b0a996414c9e89a55e344dc --- /dev/null +++ b/rgthree-comfy/py/rgthree_server.py @@ -0,0 +1,208 @@ +import os +import json +import re +import copy +import timeit +import asyncio + +from datetime import datetime + +from .utils import path_exists +from .utils_server import get_param, is_param_falsy +from .utils_info import delete_model_info, get_model_info, set_model_info_partial + +from server import PromptServer +from aiohttp import web + +import folder_paths + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +DIR_WEB = os.path.abspath(f'{THIS_DIR}/../web/') + +routes = PromptServer.instance.routes + + +def set_default_page_resources(path): + """ Sets up routes for handling static files under a path.""" + + @routes.get(f'/rgthree/{path}/{{file}}') + async def get_resource(request): + """ Returns a resource file. """ + return web.FileResponse(os.path.join(DIR_WEB, path, request.match_info['file'])) + + @routes.get(f'/rgthree/{path}/{{subdir}}/{{file}}') + async def get_resource_subdir(request): + """ Returns a resource file. """ + return web.FileResponse( + os.path.join(DIR_WEB, path, request.match_info['subdir'], request.match_info['file'])) + + +def set_default_page_routes(path): + """ Sets default path handling for a hosted rgthree page. """ + + @routes.get(f'/rgthree/{path}') + async def get_path_redir(request): + """ Redirects to the path adding a trailing slash. """ + raise web.HTTPFound(f'{request.path}/') + + @routes.get(f'/rgthree/{path}/') + async def get_path_index(request): + """ Handles the page's index loading. """ + html = '' + with open(os.path.join(DIR_WEB, path, 'index.html'), 'r', encoding='UTF-8') as file: + html = file.read() + return web.Response(text=html, content_type='text/html') + + set_default_page_resources(path) + + +# Sometimes other pages (link_fixer, etc.) may want to import JS from the comfyui +# directory. To allow TS to resolve imports like '../comfyui/file.js', we'll also serve those +# modules over HTTP at these routes. +set_default_page_resources("comfyui") +set_default_page_resources("common") + +set_default_page_routes("link_fixer") + +# Configuration +from .config import RGTHREE_CONFIG, set_user_config + + +@routes.get('/rgthree/config.js') +def api_get_user_config_file(request): + """ Returns the user configuration as a JavaScript file. """ + data_str = json.dumps(RGTHREE_CONFIG, sort_keys=True, indent=2, separators=(",", ": ")) + text = f'export const rgthreeConfig = {data_str}' + return web.Response(text=text, content_type='application/javascript') + + +@routes.get('/rgthree/api/config') +def api_get_user_config(request): + """ Returns the user configuration. """ + return web.json_response(RGTHREE_CONFIG) + + +@routes.post('/rgthree/api/config') +async def api_set_user_config(request): + """ Sets the user configuration.
""" + post = await request.post() + data = json.loads(post.get("json")) + set_user_config(data) + return web.json_response({"status": "ok"}) + + +# General + + +@routes.get('/rgthree/api/loras') +async def api_get_loras(request): + """ Returns a list of loras user configuration. """ + data = folder_paths.get_filename_list("loras") + return web.json_response(list(data)) + + +@routes.get('/rgthree/api/loras/info') +async def api_get_loras_info(request): + """ Returns a list loras info; either all or a single if provided a 'file' param. """ + lora_file = get_param(request, 'file') + maybe_fetch_metadata = lora_file is not None + if not is_param_falsy(request, 'light'): + maybe_fetch_metadata = False + api_response = await get_loras_info_response(request, maybe_fetch_metadata=maybe_fetch_metadata) + return web.json_response(api_response) + + +@routes.get('/rgthree/api/loras/info/clear') +async def delete_lora_info(request): + """Clears lora info from the filesystem for the provided file.""" + api_response = {'status': 200} + lora_file = get_param(request, 'file') + del_info = not is_param_falsy(request, 'del_info') + del_metadata = not is_param_falsy(request, 'del_metadata') + del_civitai = not is_param_falsy(request, 'del_civitai') + if lora_file is None: + api_response['status'] = '404' + api_response['error'] = 'No Lora file provided' + elif lora_file == "ALL": # Force the user to supply file=ALL to trigger all clearing. + lora_files = folder_paths.get_filename_list("loras") + for lora_file in lora_files: + await delete_model_info(lora_file, del_info=del_info, del_metadata=del_metadata, del_civitai=del_civitai) + else: + await delete_model_info(lora_file, del_info=del_info, del_metadata=del_metadata, del_civitai=del_civitai) + return web.json_response(api_response) + + +@routes.get('/rgthree/api/loras/info/refresh') +async def refresh_get_loras_info(request): + """ Refreshes lora info; either all or a single if provided a 'file' param. """ + api_response = await get_loras_info_response(request, + maybe_fetch_civitai=True, + maybe_fetch_metadata=True) + return web.json_response(api_response) + + +async def get_loras_info_response(request, maybe_fetch_civitai=False, maybe_fetch_metadata=False): + """Gets lora info for all or a single lora""" + api_response = {'status': 200} + lora_file = get_param(request, 'file') + light = not is_param_falsy(request, 'light') + if lora_file is not None: + info_data = await get_model_info(lora_file, + maybe_fetch_civitai=maybe_fetch_civitai, + maybe_fetch_metadata=maybe_fetch_metadata, + light=light) + if info_data is None: + api_response['status'] = '404' + api_response['error'] = 'No Lora found at path' + else: + api_response['data'] = info_data + else: + api_response['data'] = [] + lora_files = folder_paths.get_filename_list("loras") + for lora_file in lora_files: + info_data = await get_model_info(lora_file, + maybe_fetch_civitai=maybe_fetch_civitai, + maybe_fetch_metadata=maybe_fetch_metadata, + light=light) + api_response['data'].append(info_data) + return api_response + + +@routes.post('/rgthree/api/loras/info') +async def api_save_lora_data(request): + """Saves data to a lora by name. 
""" + api_response = {'status': 200} + lora_file = get_param(request, 'file') + if lora_file is None: + api_response['status'] = '404' + api_response['error'] = 'No Lora found at path' + else: + post = await request.post() + await set_model_info_partial(lora_file, json.loads(post.get("json"))) + info_data = await get_model_info(lora_file) + api_response['data'] = info_data + return web.json_response(api_response) + + +@routes.get('/rgthree/api/loras/img') +async def api_get_loras_info_img(request): + """ Returns an image response if one exists for the lora. """ + lora_file = get_param(request, 'file') + lora_path = folder_paths.get_full_path("loras", lora_file) + if not path_exists(lora_path): + lora_path = os.path.abspath(lora_path) + + img_path = None + for ext in ['jpg', 'png', 'jpeg']: + try_path = f'{os.path.splitext(lora_path)[0]}.{ext}' + if path_exists(try_path): + img_path = try_path + break + + if not path_exists(img_path): + api_response = {} + api_response['status'] = '404' + api_response['error'] = 'No Lora found at path' + return web.json_response(api_response) + + return web.FileResponse(img_path) diff --git a/rgthree-comfy/py/sdxl_empty_latent_image.py b/rgthree-comfy/py/sdxl_empty_latent_image.py new file mode 100644 index 0000000000000000000000000000000000000000..87b86ce4ae23763615ea0e45cb6fa7cbdcafb0a7 --- /dev/null +++ b/rgthree-comfy/py/sdxl_empty_latent_image.py @@ -0,0 +1,63 @@ +from nodes import EmptyLatentImage +from .constants import get_category, get_name + + +class RgthreeSDXLEmptyLatentImage: + + NAME = get_name('SDXL Empty Latent Image') + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": { + "dimensions": ( + [ + # 'Custom', + '1536 x 640 (landscape)', + '1344 x 768 (landscape)', + '1216 x 832 (landscape)', + '1152 x 896 (landscape)', + '1024 x 1024 (square)', + ' 896 x 1152 (portrait)', + ' 832 x 1216 (portrait)', + ' 768 x 1344 (portrait)', + ' 640 x 1536 (portrait)', + ], + { + "default": '1024 x 1024 (square)' + }), + "clip_scale": ("FLOAT", { + "default": 2.0, + "min": 1.0, + "max": 10.0, + "step": .5 + }), + "batch_size": ("INT", { + "default": 1, + "min": 1, + "max": 64 + }), + }, + # "optional": { + # "custom_width": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 64}), + # "custom_height": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 64}), + # } + } + + RETURN_TYPES = ("LATENT", "INT", "INT") + RETURN_NAMES = ("LATENT", "CLIP_WIDTH", "CLIP_HEIGHT") + FUNCTION = "generate" + + def generate(self, dimensions, clip_scale, batch_size): + """Generates the latent and exposes the clip_width and clip_height""" + if True: + result = [x.strip() for x in dimensions.split('x')] + width = int(result[0]) + height = int(result[1].split(' ')[0]) + latent = EmptyLatentImage().generate(width, height, batch_size)[0] + return ( + latent, + int(width * clip_scale), + int(height * clip_scale), + ) diff --git a/rgthree-comfy/py/sdxl_power_prompt_postive.py b/rgthree-comfy/py/sdxl_power_prompt_postive.py new file mode 100644 index 0000000000000000000000000000000000000000..68089180de87fc681e727c15ccc9d8879f873abb --- /dev/null +++ b/rgthree-comfy/py/sdxl_power_prompt_postive.py @@ -0,0 +1,168 @@ +import os +import re +from nodes import MAX_RESOLUTION +from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL + +from .log import log_node_warn, log_node_info, log_node_success +from .constants import get_category, get_name +from .power_prompt_utils import 
get_and_strip_loras +from nodes import LoraLoader, CLIPTextEncode +import folder_paths + +NODE_NAME = get_name('SDXL Power Prompt - Positive') + + +class RgthreeSDXLPowerPromptPositive: + """The Power Prompt for positive conditioning.""" + + NAME = NODE_NAME + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + SAVED_PROMPTS_FILES = folder_paths.get_filename_list('saved_prompts') + SAVED_PROMPTS_CONTENT = [] + for filename in SAVED_PROMPTS_FILES: + with open(folder_paths.get_full_path('saved_prompts', filename), 'r') as f: + SAVED_PROMPTS_CONTENT.append(f.read()) + return { + 'required': { + 'prompt_g': ('STRING', { + 'multiline': True + }), + 'prompt_l': ('STRING', { + 'multiline': True + }), + }, + 'optional': { + "opt_model": ("MODEL",), + "opt_clip": ("CLIP",), + "opt_clip_width": ("INT", { + "forceInput": True, + "default": 1024.0, + "min": 0, + "max": MAX_RESOLUTION + }), + "opt_clip_height": ("INT", { + "forceInput": True, + "default": 1024.0, + "min": 0, + "max": MAX_RESOLUTION + }), + 'insert_lora': (['CHOOSE', 'DISABLE LORAS'] + + [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('loras')],), + 'insert_embedding': ([ + 'CHOOSE', + ] + [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('embeddings')],), + 'insert_saved': ([ + 'CHOOSE', + ] + SAVED_PROMPTS_FILES,), + # We'll hide these in the UI for now. + "target_width": ("INT", { + "default": -1, + "min": -1, + "max": MAX_RESOLUTION + }), + "target_height": ("INT", { + "default": -1, + "min": -1, + "max": MAX_RESOLUTION + }), + "crop_width": ("INT", { + "default": -1, + "min": -1, + "max": MAX_RESOLUTION + }), + "crop_height": ("INT", { + "default": -1, + "min": -1, + "max": MAX_RESOLUTION + }), + }, + 'hidden': { + 'values_insert_saved': (['CHOOSE'] + SAVED_PROMPTS_CONTENT,), + } + } + + RETURN_TYPES = ('CONDITIONING', 'MODEL', 'CLIP', 'STRING', 'STRING') + RETURN_NAMES = ('CONDITIONING', 'MODEL', 'CLIP', 'TEXT_G', 'TEXT_L') + FUNCTION = 'main' + + def main(self, + prompt_g, + prompt_l, + opt_model=None, + opt_clip=None, + opt_clip_width=None, + opt_clip_height=None, + insert_lora=None, + insert_embedding=None, + insert_saved=None, + target_width=-1, + target_height=-1, + crop_width=-1, + crop_height=-1, + values_insert_saved=None): + + if insert_lora == 'DISABLE LORAS': + prompt_g, loras_g, _skipped, _unfound = get_and_strip_loras(prompt_g, + True, + log_node=self.NAME) + prompt_l, loras_l, _skipped, _unfound = get_and_strip_loras(prompt_l, + True, + log_node=self.NAME) + loras = loras_g + loras_l + log_node_info( + NODE_NAME, + f'Disabling all found loras ({len(loras)}) and stripping lora tags for TEXT output.') + elif opt_model is not None and opt_clip is not None: + prompt_g, loras_g, _skipped, _unfound = get_and_strip_loras(prompt_g, log_node=self.NAME) + prompt_l, loras_l, _skipped, _unfound = get_and_strip_loras(prompt_l, log_node=self.NAME) + loras = loras_g + loras_l + if len(loras) > 0: + for lora in loras: + opt_model, opt_clip = LoraLoader().load_lora(opt_model, opt_clip, lora['lora'], + lora['strength'], lora['strength']) + log_node_success(NODE_NAME, f'Loaded "{lora["lora"]}" from prompt') + log_node_info(NODE_NAME, f'{len(loras)} Loras processed; stripping tags for TEXT output.') + elif ' 0 else opt_clip_width + target_height = target_height if target_height and target_height > 0 else opt_clip_height + crop_width = crop_width if crop_width and crop_width > 0 else 0 + crop_height = crop_height if crop_height and 
crop_height > 0 else 0 + conditioning = CLIPTextEncodeSDXL().encode(opt_clip, opt_clip_width, opt_clip_height, + crop_width, crop_height, target_width, + target_height, prompt_g, prompt_l)[0] + else: + # If we got an opt_clip, but no clip_width or _height, then use normal CLIPTextEncode + log_node_info( + self.NAME, + 'CLIP supplied, but not CLIP_WIDTH and CLIP_HEIGHT. Text encoding will use standard encoding with prompt_g and prompt_l concatenated.' + ) + conditioning = CLIPTextEncode().encode( + opt_clip, f'{prompt_g if prompt_g else ""}\n{prompt_l if prompt_l else ""}')[0] + return conditioning diff --git a/rgthree-comfy/py/sdxl_power_prompt_simple.py b/rgthree-comfy/py/sdxl_power_prompt_simple.py new file mode 100644 index 0000000000000000000000000000000000000000..fae8b6826c02abede334dcdfc6f51fbb44ad1d8d --- /dev/null +++ b/rgthree-comfy/py/sdxl_power_prompt_simple.py @@ -0,0 +1,106 @@ +"""A simpler SDXL Power Prompt that doesn't load Loras, like for negative.""" +import os +import re +import folder_paths +from nodes import MAX_RESOLUTION, LoraLoader +from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL +from .sdxl_power_prompt_postive import RgthreeSDXLPowerPromptPositive + +from .log import log_node_warn, log_node_info, log_node_success + +from .constants import get_category, get_name + +NODE_NAME = get_name('SDXL Power Prompt - Simple / Negative') + + +class RgthreeSDXLPowerPromptSimple(RgthreeSDXLPowerPromptPositive): + """A simpler SDXL Power Prompt that doesn't handle Loras.""" + + NAME = NODE_NAME + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + saved_prompts_files = folder_paths.get_filename_list('saved_prompts') + saved_promptes_content = [] + for fname in saved_prompts_files: + with open(folder_paths.get_full_path('saved_prompts', fname), 'r', encoding="utf-8") as file: + saved_promptes_content.append(file.read()) + + return { + 'required': { + 'prompt_g': ('STRING', { + 'multiline': True + }), + 'prompt_l': ('STRING', { + 'multiline': True + }), + }, + 'optional': { + "opt_clip": ("CLIP",), + "opt_clip_width": ("INT", { + "forceInput": True, + "default": 1024.0, + "min": 0, + "max": MAX_RESOLUTION + }), + "opt_clip_height": ("INT", { + "forceInput": True, + "default": 1024.0, + "min": 0, + "max": MAX_RESOLUTION + }), + 'insert_embedding': ([ + 'CHOOSE', + ] + [os.path.splitext(x)[0] for x in folder_paths.get_filename_list('embeddings')],), + 'insert_saved': ([ + 'CHOOSE', + ] + saved_prompts_files,), + # We'll hide these in the UI for now. 
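+ # A value of -1 (the default) is treated as unset: target sizes fall back to the clip width/height and crop values fall back to 0 at encode time.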
+ "target_width": ("INT", { + "default": -1, + "min": -1, + "max": MAX_RESOLUTION + }), + "target_height": ("INT", { + "default": -1, + "min": -1, + "max": MAX_RESOLUTION + }), + "crop_width": ("INT", { + "default": -1, + "min": -1, + "max": MAX_RESOLUTION + }), + "crop_height": ("INT", { + "default": -1, + "min": -1, + "max": MAX_RESOLUTION + }), + }, + 'hidden': { + 'values_insert_saved': (['CHOOSE'] + saved_promptes_content,), + } + } + + RETURN_TYPES = ('CONDITIONING', 'STRING', 'STRING') + RETURN_NAMES = ('CONDITIONING', 'TEXT_G', 'TEXT_L') + FUNCTION = 'main' + + def main(self, + prompt_g, + prompt_l, + opt_clip=None, + opt_clip_width=None, + opt_clip_height=None, + insert_embedding=None, + insert_saved=None, + target_width=-1, + target_height=-1, + crop_width=-1, + crop_height=-1, + values_insert_saved=None): + + conditioning = self.get_conditioning(prompt_g, prompt_l, opt_clip, opt_clip_width, + opt_clip_height, target_width, target_height, crop_width, crop_height) + return (conditioning, prompt_g, prompt_l) diff --git a/rgthree-comfy/py/seed.py b/rgthree-comfy/py/seed.py new file mode 100644 index 0000000000000000000000000000000000000000..061cbfedec80bfed21cb0377626ea459ea2dabd3 --- /dev/null +++ b/rgthree-comfy/py/seed.py @@ -0,0 +1,115 @@ +"""See node.""" +import random +from datetime import datetime + +from .constants import get_category, get_name +from .log import log_node_warn, log_node_info + +# Some extension must be setting a seed as server-generated seeds were not random. We'll set a new +# seed and use that state going forward. +initial_random_state = random.getstate() +random.seed(datetime.now().timestamp()) +rgthree_seed_random_state = random.getstate() +random.setstate(initial_random_state) + + +def new_random_seed(): + """ Gets a new random seed from the rgthree_seed_random_state and resetting the previous state.""" + global rgthree_seed_random_state + prev_random_state = random.getstate() + random.setstate(rgthree_seed_random_state) + seed = random.randint(1, 1125899906842624) + rgthree_seed_random_state = random.getstate() + random.setstate(prev_random_state) + return seed + + +class RgthreeSeed: + """Seed node.""" + + NAME = get_name('Seed') + CATEGORY = get_category() + + @classmethod + def INPUT_TYPES(cls): # pylint: disable = invalid-name, missing-function-docstring + return { + "required": { + "seed": ("INT", { + "default": 0, + "min": -1125899906842624, + "max": 1125899906842624 + }), + }, + "hidden": { + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("INT",) + RETURN_NAMES = ("SEED",) + FUNCTION = "main" + + def main(self, seed=0, prompt=None, extra_pnginfo=None, unique_id=None): + """Returns the passed seed on execution.""" + + # We generate random seeds on the frontend in the seed node before sending the workflow in for + # many reasons. However, if we want to use this in an API call without changing the seed before + # sending, then users _could_ pass in "-1" and get a random seed used and added to the metadata. + # Though, this should likely be discouraged for several reasons (thus, a lot of logging). + if seed in (-1, -2, -3): + log_node_warn(self.NAME, + f'Got "{seed}" as passed seed. 
' + + 'This shouldn\'t happen when queueing from the ComfyUI frontend.', + msg_color="YELLOW") + if seed in (-2, -3): + log_node_warn(self.NAME, + f'Cannot {"increment" if seed == -2 else "decrement"} seed from ' + + 'server, but will generate a new random seed.', + msg_color="YELLOW") + + original_seed = seed + seed = new_random_seed() + log_node_info(self.NAME, f'Server-generated random seed {seed} and saving to workflow.') + log_node_warn( + self.NAME, + 'NOTE: Re-queues passing in "{seed}" and server-generated random seed won\'t be cached.', + msg_color="YELLOW") + + if unique_id is None: + log_node_warn( + self.NAME, 'Cannot save server-generated seed to image metadata because ' + + 'the node\'s id was not provided.') + else: + if extra_pnginfo is None: + log_node_warn( + self.NAME, 'Cannot save server-generated seed to image workflow ' + + 'metadata because workflow was not provided.') + else: + workflow_node = next( + (x for x in extra_pnginfo['workflow']['nodes'] if x['id'] == int(unique_id)), None) + if workflow_node is None or 'widgets_values' not in workflow_node: + log_node_warn( + self.NAME, 'Cannot save server-generated seed to image workflow ' + + 'metadata because node was not found in the provided workflow.') + else: + for index, widget_value in enumerate(workflow_node['widgets_values']): + if widget_value == original_seed: + workflow_node['widgets_values'][index] = seed + + if prompt is None: + log_node_warn( + self.NAME, 'Cannot save server-generated seed to image API prompt ' + + 'metadata because prompt was not provided.') + else: + prompt_node = prompt[str(unique_id)] + if prompt_node is None or 'inputs' not in prompt_node or 'seed' not in prompt_node[ + 'inputs']: + log_node_warn( + self.NAME, 'Cannot save server-generated seed to image workflow ' + + 'metadata because node was not found in the provided workflow.') + else: + prompt_node['inputs']['seed'] = seed + + return (seed,) diff --git a/rgthree-comfy/py/utils.py b/rgthree-comfy/py/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..20c01967ab1f7182ef08def4364897bbfabea862 --- /dev/null +++ b/rgthree-comfy/py/utils.py @@ -0,0 +1,124 @@ +import json +import os +import re + + +class AnyType(str): + """A special class that is always equal in not equal comparisons. Credit to pythongosssss""" + + def __ne__(self, __value: object) -> bool: + return False + +class FlexibleOptionalInputType(dict): + """A special class to make flexible nodes that pass data to our python handlers. + + Enables both flexible/dynamic input types (like for Any Switch) or a dynamic number of inputs + (like for Any Switch, Context Switch, Context Merge, Power Lora Loader, etc). + + Note, for ComfyUI, all that's needed is the `__contains__` override below, which tells ComfyUI + that our node will handle the input, regardless of what it is. + + However, with https://github.com/comfyanonymous/ComfyUI/pull/2666 a large change would occur + requiring more details on the input itself. There, we need to return a list/tuple where the first + item is the type. This can be a real type, or use the AnyType for additional flexibility. + + This should be forwards compatible unless more changes occur in the PR. 
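To make the behavior this docstring describes concrete, here is a standalone demo (the two small classes are re-declared so it runs outside ComfyUI) showing why any input name is accepted and why the `*` type never rejects a concrete type:

```python
class AnyType(str):
    """A string that never reports inequality, so "*" matches any slot type."""
    def __ne__(self, other) -> bool:
        return False

class FlexibleOptionalInputType(dict):
    """A dict that claims to contain every key and answers every lookup with a
    (type,) tuple, so ComfyUI accepts arbitrarily named/numbered inputs."""
    def __init__(self, type):
        self.type = type
    def __getitem__(self, key):
        return (self.type,)
    def __contains__(self, key):
        return True

any_type = AnyType("*")
optional = FlexibleOptionalInputType(any_type)

assert "lora_17" in optional                 # any input name is "declared"
assert optional["anything_at_all"] == (any_type,)
assert (any_type != "MODEL") is False        # "*" never rejects a concrete type
```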
+ """ + def __init__(self, type): + self.type = type + + def __getitem__(self, key): + return (self.type, ) + + def __contains__(self, key): + return True + + +any_type = AnyType("*") + + +def is_dict_value_falsy(data: dict, dict_key: str): + """ Checks if a dict value is falsy.""" + val = get_dict_value(data, dict_key) + return not val + + +def get_dict_value(data: dict, dict_key: str, default=None): + """ Gets a deeply nested value given a dot-delimited key.""" + keys = dict_key.split('.') + key = keys.pop(0) if len(keys) > 0 else None + found = data[key] if key in data else None + if found is not None and len(keys) > 0: + return get_dict_value(found, '.'.join(keys), default) + return found if found is not None else default + + +def set_dict_value(data: dict, dict_key: str, value, create_missing_objects=True): + """ Sets a deeply nested value given a dot-delimited key.""" + keys = dict_key.split('.') + key = keys.pop(0) if len(keys) > 0 else None + if key not in data: + if create_missing_objects == False: + return + data[key] = {} + if len(keys) == 0: + data[key] = value + else: + set_dict_value(data[key], '.'.join(keys), value, create_missing_objects) + + return data + + +def dict_has_key(data: dict, dict_key): + """ Checks if a dict has a deeply nested dot-delimited key.""" + keys = dict_key.split('.') + key = keys.pop(0) if len(keys) > 0 else None + if key is None or key not in data: + return False + if len(keys) == 0: + return True + return dict_has_key(data[key], '.'.join(keys)) + + +def load_json_file(file: str, default=None): + """Reads a json file and returns the json dict, stripping out "//" comments first.""" + if path_exists(file): + with open(file, 'r', encoding='UTF-8') as file: + config = file.read() + try: + return json.loads(config) + except json.decoder.JSONDecodeError: + try: + config = re.sub(r"^\s*//\s.*", "", config, flags=re.MULTILINE) + return json.loads(config) + except json.decoder.JSONDecodeError: + try: + config = re.sub(r"(?:^|\s)//.*", "", config, flags=re.MULTILINE) + return json.loads(config) + except json.decoder.JSONDecodeError: + pass + return default + + +def save_json_file(file_path: str, data: dict): + """Saves a json file.""" + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, 'w+', encoding='UTF-8') as file: + json.dump(data, file, sort_keys=False, indent=2, separators=(",", ": ")) + +def path_exists(path): + """Checks if a path exists, accepting None type.""" + if path is not None: + return os.path.exists(path) + return False + + +class ByPassTypeTuple(tuple): + """A special class that will return additional "AnyType" strings beyond defined values. 
+ Credit to Trung0246 + """ + + def __getitem__(self, index): + if index > len(self) - 1: + return AnyType("*") + return super().__getitem__(index) diff --git a/rgthree-comfy/py/utils_info.py b/rgthree-comfy/py/utils_info.py new file mode 100644 index 0000000000000000000000000000000000000000..6996ccc76d434e293b2fbdfc976df9e1b49a2855 --- /dev/null +++ b/rgthree-comfy/py/utils_info.py @@ -0,0 +1,421 @@ +import hashlib +import requests +import json +import torch +import re +import os +import copy + +from datetime import datetime + +from .utils import get_dict_value, load_json_file, path_exists, save_json_file +from .utils_userdata import read_userdata_json, save_userdata_json, delete_userdata_file + +import folder_paths +from server import PromptServer + + +def _get_info_cache_file(data_type: str, file_hash: str): + return f'info/{file_hash}.{data_type}.json' + + +async def delete_model_info(file: str, + model_type="loras", + del_info=True, + del_metadata=True, + del_civitai=True): + """Delete the info json, and the civitai & metadata caches.""" + file_path = get_folder_path(file, model_type) + if file_path is None: + return + if del_info: + try_info_path = f'{file_path}.rgthree-info.json' + if os.path.isfile(try_info_path): + os.remove(try_info_path) + if del_civitai or del_metadata: + file_hash = _get_sha256_hash(file_path) + if del_civitai: + json_file_path = _get_info_cache_file(file_hash, 'civitai') + delete_userdata_file(json_file_path) + if del_metadata: + json_file_path = _get_info_cache_file(file_hash, 'metadata') + delete_userdata_file(json_file_path) + + +async def get_model_info(file: str, + model_type="loras", + default=None, + maybe_fetch_civitai=False, + force_fetch_civitai=False, + maybe_fetch_metadata=False, + force_fetch_metadata=False, + light=False): + """Compiles a model info given a stored file next to the model, and/or metadata/civitai.""" + + file_path = get_folder_path(file, model_type) + if file_path is None: + return default + + info_data = {} + should_save = False + # Try to load a rgthree-info.json file next to the file. + try_info_path = f'{file_path}.rgthree-info.json' + if path_exists(try_info_path): + info_data = load_json_file(try_info_path) + + if 'file' not in info_data: + info_data['file'] = file + should_save = True + if 'path' not in info_data: + info_data['path'] = file_path + should_save = True + + # Check if we have an image next to the file and, if so, add it to the front of the images + # (if it isn't already). + img_next_to_file = None + for ext in ['jpg', 'png', 'jpeg']: + try_path = f'{os.path.splitext(file_path)[0]}.{ext}' + if path_exists(try_path): + img_next_to_file = try_path + break + + if 'images' not in info_data: + info_data['images'] = [] + should_save = True + + if img_next_to_file: + img_next_to_file_url = f'/rgthree/api/loras/img?file={file}' + if len(info_data['images']) == 0 or info_data['images'][0]['url'] != img_next_to_file_url: + info_data['images'].insert(0, {'url': img_next_to_file_url}) + should_save = True + + # If we just want light data then bail now with just existing data, plus file, path and img if + # next to the file. 
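For orientation, the sidecar-file conventions used by `get_model_info` above can be sketched without ComfyUI; the paths below are illustrative and `sidecar_info_path`/`find_preview_image` are hypothetical helper names:

```python
import os

def sidecar_info_path(model_path: str) -> str:
    # The compiled info json is stored right next to the model file.
    return f'{model_path}.rgthree-info.json'

def find_preview_image(model_path: str):
    # The first jpg/png/jpeg sharing the model's basename is used as the
    # leading preview image.
    base = os.path.splitext(model_path)[0]
    for ext in ('jpg', 'png', 'jpeg'):
        candidate = f'{base}.{ext}'
        if os.path.exists(candidate):
            return candidate
    return None

print(sidecar_info_path('loras/detail_tweaker.safetensors'))
# -> loras/detail_tweaker.safetensors.rgthree-info.json
```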
+ if light and not maybe_fetch_metadata and not force_fetch_metadata and not maybe_fetch_civitai and not force_fetch_civitai: + return info_data + + if 'raw' not in info_data: + info_data['raw'] = {} + should_save = True + + should_save = _update_data(info_data) or should_save + + should_fetch_civitai = force_fetch_civitai is True or (maybe_fetch_civitai is True and + 'civitai' not in info_data['raw']) + should_fetch_metadata = force_fetch_metadata is True or (maybe_fetch_metadata is True and + 'metadata' not in info_data['raw']) + + if should_fetch_metadata: + data_meta = _get_model_metadata(file, + model_type=model_type, + default={}, + refresh=force_fetch_metadata) + should_save = _merge_metadata(info_data, data_meta) or should_save + + if should_fetch_civitai: + data_civitai = _get_model_civitai_data(file, + model_type=model_type, + default={}, + refresh=force_fetch_civitai) + should_save = _merge_civitai_data(info_data, data_civitai) or should_save + + if 'sha256' not in info_data: + file_hash = _get_sha256_hash(file_path) + if file_hash is not None: + info_data['sha256'] = file_hash + should_save = True + + if should_save: + if 'trainedWords' in info_data: + # Sort by count; if it doesn't exist, then assume it's a top item from civitai or elsewhere. + info_data['trainedWords'] = sorted(info_data['trainedWords'], + key=lambda w: w['count'] if 'count' in w else 99999, + reverse=True) + save_model_info(file, info_data, model_type=model_type) + + # If we're saving, then the UI is likely waiting to see if the refreshed data is coming in. + await PromptServer.instance.send("rgthree-refreshed-lora-info", {"data": info_data}) + + return info_data + + +def _update_data(info_data: dict) -> bool: + """Ports old data to new data if necessary.""" + should_save = False + # If we have "triggerWords" then move them over to "trainedWords" + if 'triggerWords' in info_data and len(info_data['triggerWords']) > 0: + civitai_words = ','.join((get_dict_value(info_data, 'raw.civitai.triggerWords', default=[]) + + get_dict_value(info_data, 'raw.civitai.trainedWords', default=[]))) + if 'trainedWords' not in info_data: + info_data['trainedWords'] = [] + for trigger_word in info_data['triggerWords']: + word_data = next((data for data in info_data['trainedWords'] if data['word'] == trigger_word), + None) + if word_data is None: + word_data = {'word': trigger_word} + info_data['trainedWords'].append(word_data) + if trigger_word in civitai_words: + word_data['civitai'] = True + else: + word_data['user'] = True + + del info_data['triggerWords'] + should_save = True + return should_save + + +def _merge_metadata(info_data: dict, data_meta: dict) -> bool: + """Returns true if data was saved.""" + should_save = False + + base_model_file = get_dict_value(data_meta, 'ss_sd_model_name', None) + if base_model_file: + info_data['baseModelFile'] = base_model_file + + # Loop over metadata tags + trained_words = {} + if 'ss_tag_frequency' in data_meta and isinstance(data_meta['ss_tag_frequency'], dict): + for bucket_value in data_meta['ss_tag_frequency'].values(): + if isinstance(bucket_value, dict): + for tag, count in bucket_value.items(): + if tag not in trained_words: + trained_words[tag] = {'word': tag, 'count': 0, 'metadata': True} + trained_words[tag]['count'] = trained_words[tag]['count'] + count + + if 'trainedWords' not in info_data: + info_data['trainedWords'] = list(trained_words.values()) + should_save = True + else: + # We can't merge, because the list may have other data, like it's part of civitaidata. 
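The `ss_tag_frequency` folding above is easier to read as a standalone sketch; the sample metadata below is made up for illustration, and the logic mirrors the loop in `_merge_metadata`:

```python
ss_tag_frequency = {
    'bucket_0': {'1girl': 40, 'red hair': 12},
    'bucket_1': {'1girl': 25, 'smile': 7},
}

trained_words = {}
for bucket in ss_tag_frequency.values():
    for tag, count in bucket.items():
        entry = trained_words.setdefault(tag, {'word': tag, 'count': 0, 'metadata': True})
        entry['count'] += count

# Counts are summed across buckets; entries are later sorted by count
# descending (entries with no count are treated as already-top items).
assert trained_words['1girl']['count'] == 65
ordered = sorted(trained_words.values(), key=lambda w: w.get('count', 99999), reverse=True)
assert ordered[0]['word'] == '1girl'
```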
+ merged_dict = {} + for existing_word_data in info_data['trainedWords']: + merged_dict[existing_word_data['word']] = existing_word_data + for new_key, new_word_data in trained_words.items(): + if new_key not in merged_dict: + merged_dict[new_key] = {} + merged_dict[new_key] = {**merged_dict[new_key], **new_word_data} + info_data['trainedWords'] = list(merged_dict.values()) + should_save = True + + # trained_words = list(trained_words.values()) + # info_data['meta_trained_words'] = trained_words + info_data['raw']['metadata'] = data_meta + should_save = True + + if 'sha256' not in info_data and '_sha256' in data_meta: + info_data['sha256'] = data_meta['_sha256'] + should_save = True + + return should_save + + +def _merge_civitai_data(info_data: dict, data_civitai: dict) -> bool: + """Returns true if data was saved.""" + should_save = False + + if 'name' not in info_data: + info_data['name'] = get_dict_value(data_civitai, 'model.name', '') + should_save = True + version_name = get_dict_value(data_civitai, 'name') + if version_name is not None: + info_data['name'] += f' - {version_name}' + + if 'type' not in info_data: + info_data['type'] = get_dict_value(data_civitai, 'model.type') + should_save = True + if 'baseModel' not in info_data: + info_data['baseModel'] = get_dict_value(data_civitai, 'baseModel') + should_save = True + + # We always want to merge triggerword. + civitai_trigger = get_dict_value(data_civitai, 'triggerWords', default=[]) + civitai_trained = get_dict_value(data_civitai, 'trainedWords', default=[]) + civitai_words = ','.join(civitai_trigger + civitai_trained) + if civitai_words: + civitai_words = re.sub(r"\s*,\s*", ",", civitai_words) + civitai_words = re.sub(r",+", ",", civitai_words) + civitai_words = re.sub(r"^,", "", civitai_words) + civitai_words = re.sub(r",$", "", civitai_words) + if civitai_words: + civitai_words = civitai_words.split(',') + if 'trainedWords' not in info_data: + info_data['trainedWords'] = [] + for trigger_word in civitai_words: + word_data = next( + (data for data in info_data['trainedWords'] if data['word'] == trigger_word), None) + if word_data is None: + word_data = {'word': trigger_word} + info_data['trainedWords'].append(word_data) + word_data['civitai'] = True + + if 'sha256' not in info_data: + info_data['sha256'] = data_civitai['_sha256'] + should_save = True + + if 'modelId' in data_civitai: + info_data['links'] = info_data['links'] if 'links' in info_data else [] + civitai_link = f'https://civitai.com/models/{get_dict_value(data_civitai, "modelId")}' + if get_dict_value(data_civitai, "id"): + civitai_link += f'?modelVersionId={get_dict_value(data_civitai, "id")}' + info_data['links'].append(civitai_link) + info_data['links'].append(data_civitai['_civitai_api']) + should_save = True + + # Take images from civitai + if 'images' in data_civitai: + info_data_image_urls = list(map(lambda i: i['url'] + if 'url' in i else None, info_data['images'])) + for img in data_civitai['images']: + img_url = get_dict_value(img, 'url') + if img_url is not None and img_url not in info_data_image_urls: + img_id = os.path.splitext(os.path.basename(img_url))[0] if img_url is not None else None + img_data = { + 'url': img_url, + 'civitaiUrl': f'https://civitai.com/images/{img_id}' if img_id is not None else None, + 'width': get_dict_value(img, 'width'), + 'height': get_dict_value(img, 'height'), + 'type': get_dict_value(img, 'type'), + 'nsfwLevel': get_dict_value(img, 'nsfwLevel'), + 'seed': get_dict_value(img, 'meta.seed'), + 'positive': get_dict_value(img, 
'meta.prompt'), + 'negative': get_dict_value(img, 'meta.negativePrompt'), + 'steps': get_dict_value(img, 'meta.steps'), + 'sampler': get_dict_value(img, 'meta.sampler'), + 'cfg': get_dict_value(img, 'meta.cfgScale'), + 'model': get_dict_value(img, 'meta.Model'), + 'resources': get_dict_value(img, 'meta.resources'), + } + info_data['images'].append(img_data) + should_save = True + + # The raw data + if 'civitai' not in info_data['raw']: + info_data['raw']['civitai'] = data_civitai + should_save = True + + return should_save + + +def _get_model_civitai_data(file: str, model_type="loras", default=None, refresh=False): + """Gets the civitai data, either cached from the user directory, or from civitai api.""" + file_hash = _get_sha256_hash(get_folder_path(file, model_type)) + if file_hash is None: + return None + + json_file_path = _get_info_cache_file(file_hash, 'civitai') + + api_url = f'https://civitai.com/api/v1/model-versions/by-hash/{file_hash}' + file_data = read_userdata_json(json_file_path) + if file_data is None or refresh is True: + try: + response = requests.get(api_url, timeout=5000) + data = response.json() + save_userdata_json(json_file_path, { + 'url': api_url, + 'timestamp': datetime.now().timestamp(), + 'response': data + }) + file_data = read_userdata_json(json_file_path) + except requests.exceptions.RequestException as e: # This is the correct syntax + print(e) + response = file_data['response'] if file_data is not None and 'response' in file_data else None + if response is not None: + response['_sha256'] = file_hash + response['_civitai_api'] = api_url + return response if response is not None else default + + +def _get_model_metadata(file: str, model_type="loras", default=None, refresh=False): + """Gets the metadata from the file itself.""" + file_path = get_folder_path(file, model_type) + file_hash = _get_sha256_hash(file_path) + if file_hash is None: + return default + + json_file_path = _get_info_cache_file(file_hash, 'metadata') + + file_data = read_userdata_json(json_file_path) + if file_data is None or refresh is True: + data = _read_file_metadata_from_header(file_path) + if data is not None: + file_data = {'url': file, 'timestamp': datetime.now().timestamp(), 'response': data} + save_userdata_json(json_file_path, file_data) + response = file_data['response'] if file_data is not None and 'response' in file_data else None + if response is not None: + response['_sha256'] = file_hash + return response if response is not None else default + + +def _read_file_metadata_from_header(file_path: str) -> dict: + """Reads the file's header and returns a JSON dict metdata if available.""" + data = None + try: + if file_path.endswith('.safetensors'): + with open(file_path, "rb") as file: + # https://github.com/huggingface/safetensors#format + # 8 bytes: N, an unsigned little-endian 64-bit integer, containing the size of the header + header_size = int.from_bytes(file.read(8), "little", signed=False) + + if header_size <= 0: + raise BufferError("Invalid header size") + + header = file.read(header_size) + if header is None: + raise BufferError("Invalid header") + + header_json = json.loads(header) + data = header_json["__metadata__"] if "__metadata__" in header_json else None + + if data is not None: + for key, value in data.items(): + if isinstance(value, str) and value.startswith('{') and value.endswith('}'): + try: + value_as_json = json.loads(value) + data[key] = value_as_json + except Exception: + print(f'metdata for field {key} did not parse as json') + except 
requests.exceptions.RequestException as e: + print(e) + data = None + + return data + + +def get_folder_path(file: str, model_type="loras"): + """Gets the file path ensuring it exists.""" + file_path = folder_paths.get_full_path(model_type, file) + if file_path and not path_exists(file_path): + file_path = os.path.abspath(file_path) + if not path_exists(file_path): + file_path = None + return file_path + + +def _get_sha256_hash(file_path: str): + """Returns the hash for the file.""" + if not file_path or not path_exists(file_path): + return None + file_hash = None + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + # Read and update hash string value in blocks of 4K + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + file_hash = sha256_hash.hexdigest() + return file_hash + + +async def set_model_info_partial(file: str, info_data_partial, model_type="loras"): + """Sets partial data into the existing model info data.""" + info_data = await get_model_info(file, model_type=model_type, default={}) + info_data = {**info_data, **info_data_partial} + save_model_info(file, info_data, model_type=model_type) + + +def save_model_info(file: str, info_data, model_type="loras"): + """Saves the model info alongside the model itself.""" + file_path = get_folder_path(file, model_type) + if file_path is None: + return + try_info_path = f'{file_path}.rgthree-info.json' + save_json_file(try_info_path, info_data) diff --git a/rgthree-comfy/py/utils_server.py b/rgthree-comfy/py/utils_server.py new file mode 100644 index 0000000000000000000000000000000000000000..d6979103bb173ac8889b8d588f36209fae1495b1 --- /dev/null +++ b/rgthree-comfy/py/utils_server.py @@ -0,0 +1,8 @@ +def get_param(request, param, default=None): + """Gets a param from a request.""" + return request.rel_url.query[param] if param in request.rel_url.query else default + +def is_param_falsy(request, param): + """Determines if a param is explicitly 0 or false.""" + val = get_param(request, param) + return val is not None and (val == "0" or val.upper() == "FALSE") diff --git a/rgthree-comfy/py/utils_userdata.py b/rgthree-comfy/py/utils_userdata.py new file mode 100644 index 0000000000000000000000000000000000000000..77d95477b4778e38a82e6d7400a66bd5e540e5de --- /dev/null +++ b/rgthree-comfy/py/utils_userdata.py @@ -0,0 +1,50 @@ +import os + +from .utils import load_json_file, path_exists, save_json_file + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +USERDATA = os.path.join(THIS_DIR, '..', 'userdata') + + +def read_userdata_file(rel_path: str): + """Reads a file from the userdata directory.""" + file_path = clean_path(rel_path) + if path_exists(file_path): + with open(file_path, 'r', encoding='UTF-8') as file: + return file.read() + return None + + +def save_userdata_file(rel_path: str, content: str): + """Saves a file from the userdata directory.""" + file_path = clean_path(rel_path) + with open(file_path, 'w+', encoding='UTF-8') as file: + file.write(content) + + +def delete_userdata_file(rel_path: str): + """Deletes a file from the userdata directory.""" + file_path = clean_path(rel_path) + if os.path.isfile(file_path): + os.remove(file_path) + + +def read_userdata_json(rel_path: str): + """Reads a json file from the userdata directory.""" + file_path = clean_path(rel_path) + return load_json_file(file_path) + + +def save_userdata_json(rel_path: str, data: dict): + """Saves a json file from the userdata directory.""" + file_path = clean_path(rel_path) + return save_json_file(file_path, 
data) + + +def clean_path(rel_path: str): + """Cleans a relative path by splitting on forward slash and os.path.joining.""" + cleaned = USERDATA + paths = rel_path.split('/') + for path in paths: + cleaned = os.path.join(cleaned, path) + return cleaned diff --git a/rgthree-comfy/requirements.txt b/rgthree-comfy/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rgthree-comfy/rgthree_config.json.default b/rgthree-comfy/rgthree_config.json.default new file mode 100644 index 0000000000000000000000000000000000000000..5d17feafd13331ef995c308cb432f28d233dac9b --- /dev/null +++ b/rgthree-comfy/rgthree_config.json.default @@ -0,0 +1,61 @@ +// COPY THIS FILE BEFORE MAKING CHANGES TO: rgthree_config.json +{ + "log_level": "WARN", + "features": { + "patch_recursive_execution": true, + "show_alerts_for_corrupt_workflows": false, + "monitor_for_corrupt_links": false, + "menu_queue_selected_nodes": true, + "menu_auto_nest": { + "subdirs": null, + "threshold": 20 + }, + "menu_bookmarks": { + "enabled": true + }, + "group_header_fast_toggle": { + "enabled": null, + "toggles": ["mute", "bypass"], + "show": "hover" + }, + "progress_bar": { + "enabled": true, + "height": 16, + "position": "top" + }, + "comfy_top_bar_menu": { + "enabled": true, + "button_bookmarks": { + "enabled": true + } + }, + // Allows for dragging and dropping a workflow (image, json) onto an individual node to import + // that specific node's widgets if it also exists in the dropped workflow (same id, type). + "import_individual_nodes": { + "enabled": null + }, + // Enables invokeExtensionsAsync for rgthree-nodes allowing other extensions to hook into the + // nodes like the default ComfyNodes. This was not possible before Apr 2024, so it's a config + // entry in case it causes issues. This is only for the nodeCreated event/function as of now. 
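As an aside on this default config: it is JSON with `//` comment lines, which plain `json.loads` rejects. The sketch below (abbreviated, illustrative content) mirrors the comment-stripping fallback in `py/utils.py`'s `load_json_file`:

```python
import json
import re

raw = '''
// COPY THIS FILE BEFORE MAKING CHANGES TO: rgthree_config.json
{
  "log_level": "WARN",
  // Comment lines like this one are stripped before parsing.
  "features": {"import_individual_nodes": {"enabled": null}}
}
'''

try:
    config = json.loads(raw)              # fails: "//" is not valid JSON
except json.JSONDecodeError:
    config = json.loads(re.sub(r"(?:^|\s)//.*", "", raw, flags=re.MULTILINE))

assert config["log_level"] == "WARN"
```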
+ "invoke_extensions_async": { + "node_created": true + } + }, + "nodes": { + "reroute": { + "default_width": 40, + "default_height": 30, + "default_resizable": false, + "default_layout": ["Left", "Right"], + "fast_reroute": { + "enabled": true, + "key_create_while_dragging_link" : "Shift + R", + "key_rotate": "Shift + A", + "key_resize": "Shift + X", + "key_move": "Shift + Z", + "key_connections_input": "Shift + S", + "key_connections_output": "Shift + D" + } + } + } +} diff --git a/rgthree-comfy/src_web/comfyui/any_switch.ts b/rgthree-comfy/src_web/comfyui/any_switch.ts new file mode 100644 index 0000000000000000000000000000000000000000..aaf3de6c6963b4a0ce4f7eb660e17d1c6746cf2e --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/any_switch.ts @@ -0,0 +1,101 @@ +import type { INodeInputSlot, INodeOutputSlot, LLink } from "typings/litegraph.js"; +import type { ComfyApp, ComfyNodeConstructor, ComfyObjectInfo } from "typings/comfy.js"; + +import { app } from "scripts/app.js"; +import { IoDirection, addConnectionLayoutSupport, followConnectionUntilType } from "./utils.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; +import { removeUnusedInputsFromEnd } from "./utils_inputs_outputs.js"; +import { debounce } from "rgthree/common/shared_utils.js"; + +class RgthreeAnySwitch extends RgthreeBaseServerNode { + static override title = NodeTypesString.ANY_SWITCH; + static override type = NodeTypesString.ANY_SWITCH; + static comfyClass = NodeTypesString.ANY_SWITCH; + + private stabilizeBound = this.stabilize.bind(this); + private nodeType: string | string[] | null = null; + + constructor(title = RgthreeAnySwitch.title) { + super(title); + // Adding five. Note, configure will add as many as was in the stored workflow automatically. + this.addAnyInput(5); + } + + override onConnectionsChange( + type: number, + slotIndex: number, + isConnected: boolean, + linkInfo: LLink, + ioSlot: INodeOutputSlot | INodeInputSlot, + ) { + super.onConnectionsChange?.(type, slotIndex, isConnected, linkInfo, ioSlot); + this.scheduleStabilize(); + } + + onConnectionsChainChange() { + this.scheduleStabilize(); + } + + scheduleStabilize(ms = 64) { + return debounce(this.stabilizeBound, ms); + } + + private addAnyInput(num = 1) { + for (let i = 0; i < num; i++) { + this.addInput( + `any_${String(this.inputs.length + 1).padStart(2, "0")}`, + (this.nodeType || "*") as string, + ); + } + } + + stabilize() { + // First, clean up the dynamic number of inputs. + removeUnusedInputsFromEnd(this, 4); + this.addAnyInput(); + + // We prefer the inputs, then the output. + let connectedType = followConnectionUntilType(this, IoDirection.INPUT, undefined, true); + if (!connectedType) { + connectedType = followConnectionUntilType(this, IoDirection.OUTPUT, undefined, true); + } + // TODO: What this doesn't do is broadcast to other nodes when its type changes. Reroute node + // does, but, for now, if this was connected to another Any Switch, say, the second one wouldn't + // change its type when the first does. The user would need to change the connections. + this.nodeType = connectedType?.type || "*"; + for (const input of this.inputs) { + input.type = this.nodeType as string; // So, types can indeed be arrays,, + } + for (const output of this.outputs) { + output.type = this.nodeType as string; // So, types can indeed be arrays,, + output.label = + output.type === "RGTHREE_CONTEXT" + ? "CONTEXT" + : Array.isArray(this.nodeType) || this.nodeType.includes(",") + ? 
connectedType?.label || String(this.nodeType) + : String(this.nodeType); + } + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, RgthreeAnySwitch); + addConnectionLayoutSupport(RgthreeAnySwitch, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + } +} + +app.registerExtension({ + name: "rgthree.AnySwitch", + async beforeRegisterNodeDef( + nodeType: ComfyNodeConstructor, + nodeData: ComfyObjectInfo, + app: ComfyApp, + ) { + if (nodeData.name === "Any Switch (rgthree)") { + RgthreeAnySwitch.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/base_any_input_connected_node.ts b/rgthree-comfy/src_web/comfyui/base_any_input_connected_node.ts new file mode 100644 index 0000000000000000000000000000000000000000..dbe6a4dfb4aedae2d931104836ad4f8a7ceda20c --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/base_any_input_connected_node.ts @@ -0,0 +1,325 @@ +import type { RgthreeBaseVirtualNodeConstructor } from "typings/rgthree.js"; +import type { + Vector2, + LLink, + INodeInputSlot, + INodeOutputSlot, + LGraphNode as TLGraphNode, + IWidget, +} from "typings/litegraph.js"; + +import { app } from "scripts/app.js"; +import { RgthreeBaseVirtualNode } from "./base_node.js"; +import { rgthree } from "./rgthree.js"; + +import { + PassThroughFollowing, + addConnectionLayoutSupport, + addMenuItem, + getConnectedInputNodes, + getConnectedInputNodesAndFilterPassThroughs, + getConnectedOutputNodes, + getConnectedOutputNodesAndFilterPassThroughs, +} from "./utils.js"; + +/** + * A Virtual Node that allows any node's output to connect to it. + */ +export class BaseAnyInputConnectedNode extends RgthreeBaseVirtualNode { + override isVirtualNode = true; + + /** + * Whether inputs show the immediate nodes, or follow and show connected nodes through + * passthrough nodes. + */ + readonly inputsPassThroughFollowing: PassThroughFollowing = PassThroughFollowing.NONE; + + debouncerTempWidth: number = 0; + schedulePromise: Promise | null = null; + + constructor(title = BaseAnyInputConnectedNode.title) { + super(title); + } + + override onConstructed() { + this.addInput("", "*"); + return super.onConstructed(); + } + + /** Schedules a promise to run a stabilization. */ + scheduleStabilizeWidgets(ms = 100) { + if (!this.schedulePromise) { + this.schedulePromise = new Promise((resolve) => { + setTimeout(() => { + this.schedulePromise = null; + this.doStablization(); + resolve(); + }, ms); + }); + } + return this.schedulePromise; + } + + override clone() { + const cloned = super.clone(); + // Copying to clipboard (and also, creating node templates) work by cloning nodes and, for some + // reason, it manually manipulates the cloned data. So, we want to keep the present input slots + // so if it's pasted/templatized the data is correct. Otherwise, clear the inputs and so the new + // node is ready to go, fresh. + if (!rgthree.canvasCurrentlyCopyingToClipboardWithMultipleNodes) { + while (cloned.inputs.length > 1) { + cloned.removeInput(cloned.inputs.length - 1); + } + if (cloned.inputs[0]) { + cloned.inputs[0].label = ""; + } + } + return cloned; + } + /** + * Ensures we have at least one empty input at the end. 
+ */ + stabilizeInputsOutputs() { + const hasEmptyInput = !this.inputs[this.inputs.length - 1]?.link; + if (!hasEmptyInput) { + this.addInput("", "*"); + } + for (let index = this.inputs.length - 2; index >= 0; index--) { + const input = this.inputs[index]!; + if (!input.link) { + this.removeInput(index); + } else { + const node = getConnectedInputNodesAndFilterPassThroughs( + this, + this, + index, + this.inputsPassThroughFollowing, + )[0]; + input.name = node?.title || ""; + } + } + } + + /** + * Stabilizes the node's inputs and widgets. + */ + private doStablization() { + if (!this.graph) { + return; + } + // When we add/remove widgets, litegraph is going to mess up the size, so we + // store it so we can retrieve it in computeSize. Hacky.. + (this as any)._tempWidth = this.size[0]; + + const linkedNodes = getConnectedInputNodesAndFilterPassThroughs(this); + this.stabilizeInputsOutputs(); + + this.handleLinkedNodesStabilization(linkedNodes); + + app.graph.setDirtyCanvas(true, true); + + // Schedule another stabilization in the future. + this.scheduleStabilizeWidgets(500); + } + + handleLinkedNodesStabilization(linkedNodes: TLGraphNode[]) { + linkedNodes; // No-op, but makes overridding in VSCode cleaner. + throw new Error("handleLinkedNodesStabilization should be overridden."); + } + + onConnectionsChainChange() { + this.scheduleStabilizeWidgets(); + } + + override onConnectionsChange( + type: number, + index: number, + connected: boolean, + linkInfo: LLink, + ioSlot: INodeOutputSlot | INodeInputSlot, + ) { + super.onConnectionsChange && + super.onConnectionsChange(type, index, connected, linkInfo, ioSlot); + if (!linkInfo) return; + // Follow outputs to see if we need to trigger an onConnectionChange. + const connectedNodes = getConnectedOutputNodesAndFilterPassThroughs(this); + for (const node of connectedNodes) { + if ((node as BaseAnyInputConnectedNode).onConnectionsChainChange) { + (node as BaseAnyInputConnectedNode).onConnectionsChainChange(); + } + } + this.scheduleStabilizeWidgets(); + } + + override removeInput(slot: number) { + (this as any)._tempWidth = this.size[0]; + return super.removeInput(slot); + } + + override addInput(name: string, type: string | -1, extra_info?: Partial) { + (this as any)._tempWidth = this.size[0]; + return super.addInput(name, type, extra_info); + } + + override addWidget( + type: T["type"], + name: string, + value: T["value"], + callback?: T["callback"] | string, + options?: T["options"], + ) { + (this as any)._tempWidth = this.size[0]; + return super.addWidget(type, name, value, callback, options); + } + + /** + * Guess this doesn't exist in Litegraph... + */ + override removeWidget(widgetOrSlot?: IWidget | number) { + (this as any)._tempWidth = this.size[0]; + super.removeWidget(widgetOrSlot); + } + + override computeSize(out: Vector2) { + let size = super.computeSize(out); + if ((this as any)._tempWidth) { + size[0] = (this as any)._tempWidth; + // We sometimes get repeated calls to compute size, so debounce before clearing. + this.debouncerTempWidth && clearTimeout(this.debouncerTempWidth); + this.debouncerTempWidth = setTimeout(() => { + (this as any)._tempWidth = null; + }, 32); + } + // If we're collapsed, then subtract the total calculated height of the other input slots. 
+ if (this.properties["collapse_connections"]) { + const rows = Math.max(this.inputs?.length || 0, this.outputs?.length || 0, 1) - 1; + size[1] = size[1] - rows * LiteGraph.NODE_SLOT_HEIGHT; + } + setTimeout(() => { + app.graph.setDirtyCanvas(true, true); + }, 16); + return size; + } + + /** + * When we connect our output, check our inputs and make sure we're not trying to connect a loop. + */ + override onConnectOutput( + outputIndex: number, + inputType: string | -1, + inputSlot: INodeInputSlot, + inputNode: TLGraphNode, + inputIndex: number, + ): boolean { + let canConnect = true; + if (super.onConnectOutput) { + canConnect = super.onConnectOutput(outputIndex, inputType, inputSlot, inputNode, inputIndex); + } + if (canConnect) { + const nodes = getConnectedInputNodes(this); // We want passthrough nodes, since they will loop. + if (nodes.includes(inputNode)) { + alert( + `Whoa, whoa, whoa. You've just tried to create a connection that loops back on itself, ` + + `a situation that could create a time paradox, the results of which could cause a ` + + `chain reaction that would unravel the very fabric of the space time continuum, ` + + `and destroy the entire universe!`, + ); + canConnect = false; + } + } + return canConnect; + } + + override onConnectInput( + inputIndex: number, + outputType: string | -1, + outputSlot: INodeOutputSlot, + outputNode: TLGraphNode, + outputIndex: number, + ): boolean { + let canConnect = true; + if (super.onConnectInput) { + canConnect = super.onConnectInput( + inputIndex, + outputType, + outputSlot, + outputNode, + outputIndex, + ); + } + if (canConnect) { + const nodes = getConnectedOutputNodes(this); // We want passthrough nodes, since they will loop. + if (nodes.includes(outputNode)) { + alert( + `Whoa, whoa, whoa. You've just tried to create a connection that loops back on itself, ` + + `a situation that could create a time paradox, the results of which could cause a ` + + `chain reaction that would unravel the very fabric of the space time continuum, ` + + `and destroy the entire universe!`, + ); + canConnect = false; + } + } + return canConnect; + } + + /** + * If something is dropped on us, just add it to the bottom. onConnectInput should already cancel + * if it's disallowed. + */ + override connectByTypeOutput( + slot: string | number, + sourceNode: TLGraphNode, + sourceSlotType: string, + optsIn: string, + ): T | null { + const lastInput = this.inputs[this.inputs.length - 1]; + if (!lastInput?.link && lastInput?.type === "*") { + var sourceSlot = sourceNode.findOutputSlotByType(sourceSlotType, false, true); + return sourceNode.connect(sourceSlot, this, slot); + } + return super.connectByTypeOutput(slot, sourceNode, sourceSlotType, optsIn); + } + + static override setUp() { + super.setUp(); + addConnectionLayoutSupport(this, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + addMenuItem(this, app, { + name: (node) => + `${node.properties?.["collapse_connections"] ? "Show" : "Collapse"} Connections`, + property: "collapse_connections", + prepareValue: (_value, node) => !node.properties?.["collapse_connections"], + callback: (_node) => { + app.graph.setDirtyCanvas(true, true); + }, + }); + } +} + +// Ok, hack time! LGraphNode's connectByType is powerful, but for our nodes, that have multiple "*" +// input types, it seems it just takes the first one, and disconnects it. I'd rather we don't do +// that and instead take the next free one. If that doesn't work, then we'll give it to the old +// method. 
+const oldLGraphNodeConnectByType = LGraphNode.prototype.connectByType; +LGraphNode.prototype.connectByType = function connectByType( + slot: string | number, + sourceNode: TLGraphNode, + sourceSlotType: string, + optsIn: string, +): T | null { + // If we're droppiong on a node, and the last input is free and an "*" type, then connect there + // first... + if (sourceNode.inputs) { + for (const [index, input] of sourceNode.inputs.entries()) { + if (!input.link && input.type === "*") { + this.connect(slot, sourceNode, index); + return null; + } + } + } + return ((oldLGraphNodeConnectByType && + oldLGraphNodeConnectByType.call(this, slot, sourceNode, sourceSlotType, optsIn)) || + null) as T; +}; diff --git a/rgthree-comfy/src_web/comfyui/base_node.ts b/rgthree-comfy/src_web/comfyui/base_node.ts new file mode 100644 index 0000000000000000000000000000000000000000..90f74579c5bc1e145e56a05a85ff997823275cdb --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/base_node.ts @@ -0,0 +1,460 @@ +import type { ComfyNodeConstructor, ComfyObjectInfo, NodeMode } from "typings/comfy.js"; +import type { + IWidget, + SerializedLGraphNode, + LGraphNode as TLGraphNode, + LGraphCanvas, + ContextMenuItem, + INodeOutputSlot, + INodeInputSlot, +} from "typings/litegraph.js"; +import type { RgthreeBaseServerNodeConstructor, RgthreeBaseVirtualNodeConstructor } from "typings/rgthree.js"; + +import { ComfyWidgets } from "scripts/widgets.js"; +import { SERVICE as KEY_EVENT_SERVICE } from "./services/key_events_services.js"; +import { app } from "scripts/app.js"; +import { LogLevel, rgthree } from "./rgthree.js"; +import { addHelpMenuItem } from "./utils.js"; +import { RgthreeHelpDialog } from "rgthree/common/dialog.js"; +import { + importIndividualNodesInnerOnDragDrop, + importIndividualNodesInnerOnDragOver, +} from "./feature_import_individual_nodes.js"; + +/** + * A base node with standard methods, directly extending the LGraphNode. + * This can be used for ui-nodes and a further base for server nodes. + */ +export abstract class RgthreeBaseNode extends LGraphNode { + /** + * Action strings that can be exposed and triggered from other nodes, like Fast Actions Button. + */ + static exposedActions: string[] = []; + + static override title: string = "__NEED_CLASS_TITLE__"; + static category = "rgthree"; + static _category = "rgthree"; // `category` seems to get reset by comfy, so reset to this after. + + /** + * The comfyClass is property ComfyUI and extensions may care about, even through it is only for + * server nodes. RgthreeBaseServerNode below overrides this with the expected value and we just + * set it here so extensions that are none the wiser don't break on some unchecked string method + * call on an undefined calue. + */ + comfyClass: string = "__NEED_COMFY_CLASS__"; + + /** Used by the ComfyUI-Manager badge. */ + readonly nickname = "rgthree"; + /** Are we a virtual node? */ + readonly isVirtualNode: boolean = false; + /** Are we able to be dropped on (if config is enabled too). */ + isDropEnabled = false; + /** A state member determining if we're currently removed. */ + removed = false; + /** A state member determining if we're currently "configuring."" */ + configuring = false; + /** A temporary width value that can be used to ensure compute size operates correctly. */ + _tempWidth = 0; + + /** Private Mode member so we can override the setter/getter and call an `onModeChange`. */ + private mode_: NodeMode; + /** An internal bool set when `onConstructed` is run. 
*/ + private __constructed__ = false; + /** The help dialog. */ + private helpDialog: RgthreeHelpDialog | null = null; + + constructor(title = RgthreeBaseNode.title, skipOnConstructedCall = true) { + super(title); + if (title == "__NEED_CLASS_TITLE__") { + throw new Error("RgthreeBaseNode needs overrides."); + } + // Ensure these exist since some other extensions will break in their onNodeCreated. + this.widgets = this.widgets || []; + this.properties = this.properties || {}; + + // Some checks we want to do after we're constructed, looking that data is set correctly and + // that our base's `onConstructed` was called (if not, set a DEV warning). + setTimeout(() => { + // Check we have a comfyClass defined. + if (this.comfyClass == "__NEED_COMFY_CLASS__") { + throw new Error("RgthreeBaseNode needs a comfy class override."); + } + // Ensure we've called onConstructed before we got here. + this.checkAndRunOnConstructed(); + }); + } + + private checkAndRunOnConstructed() { + if (!this.__constructed__) { + this.onConstructed(); + const [n, v] = rgthree.logger.logParts( + LogLevel.DEV, + `[RgthreeBaseNode] Child class did not call onConstructed for "${this.type}.`, + ); + console[n]?.(...v); + } + return this.__constructed__; + } + + onDragOver(e: DragEvent): boolean { + if (!this.isDropEnabled) return false; + return importIndividualNodesInnerOnDragOver(this, e); + } + + async onDragDrop(e: DragEvent): Promise { + if (!this.isDropEnabled) return false; + return importIndividualNodesInnerOnDragDrop(this, e); + } + + /** + * When a node is finished with construction, we must call this. Failure to do so will result in + * an error message from the timeout in this base class. This is broken out and becomes the + * responsibility of the child class because + */ + onConstructed() { + if (this.__constructed__) return false; + // This is kinda a hack, but if this.type is still null, then set it to undefined to match. + this.type = this.type ?? undefined; + this.__constructed__ = true; + rgthree.invokeExtensionsAsync("nodeCreated", this); + return this.__constructed__; + } + + override configure(info: SerializedLGraphNode): void { + this.configuring = true; + super.configure(info); + // Fix https://github.com/comfyanonymous/ComfyUI/issues/1448 locally. + // Can removed when fixed and adopted. + for (const w of this.widgets || []) { + w.last_y = w.last_y || 0; + } + this.configuring = false; + } + + /** + * Override clone for, at the least, deep-copying properties. + */ + override clone() { + const cloned = super.clone(); + // This is whild, but LiteGraph clone doesn't deep clone data, so we will. We'll use structured + // clone, which most browsers in 2022 support, but but we'll check. + if (cloned.properties && !!window.structuredClone) { + cloned.properties = structuredClone(cloned.properties); + } + return cloned; + } + + // @ts-ignore - Changing the property to an accessor here seems to work, but ts compiler complains. + override set mode(mode: NodeMode) { + if (this.mode_ != mode) { + const oldMode = this.mode_; + this.mode_ = mode; + this.onModeChange(oldMode, mode); + } + } + override get mode() { + return this.mode_; + } + + /** When a mode change, we want all connected nodes to match. */ + onModeChange(from: NodeMode, to: NodeMode) { + // Override + } + + /** + * Given a string, do something. At the least, handle any `exposedActions` that may be called and + * passed into from other nodes, like Fast Actions Button + */ + async handleAction(action: string) { + action; // No-op. 
Should be overridden but OK if not. + } + + /** + * Guess this doesn't exist in Litegraph... + */ + removeWidget(widgetOrSlot?: IWidget | number) { + if (typeof widgetOrSlot === "number") { + this.widgets.splice(widgetOrSlot, 1); + } else if (widgetOrSlot) { + const index = this.widgets.indexOf(widgetOrSlot); + if (index > -1) { + this.widgets.splice(index, 1); + } + } + } + + /** + * A default version of the logive when a node does not set `getSlotMenuOptions`. This is + * necessary because child nodes may want to define getSlotMenuOptions but LiteGraph then won't do + * it's default logic. This bakes it so child nodes can call this instead (and this doesn't set + * getSlotMenuOptions for all child nodes in case it doesn't exist). + */ + defaultGetSlotMenuOptions(slot: { + input?: INodeInputSlot; + output?: INodeOutputSlot; + }): ContextMenuItem[] | null { + const menu_info: ContextMenuItem[] = []; + if (slot?.output?.links?.length) { + menu_info.push({ content: "Disconnect Links", slot: slot }); + } + let inputOrOutput = slot.input || slot.output; + if (inputOrOutput) { + if (inputOrOutput.removable) { + menu_info.push( + inputOrOutput.locked ? { content: "Cannot remove" } : { content: "Remove Slot", slot }, + ); + } + if (!inputOrOutput.nameLocked) { + menu_info.push({ content: "Rename Slot", slot }); + } + } + return menu_info; + } + + override onRemoved(): void { + super.onRemoved?.(); + this.removed = true; + } + + static setUp(...args: any[]) { + // No-op. + } + + /** + * A function to provide help text to be overridden. + */ + getHelp() { + return ""; + } + + showHelp() { + const help = this.getHelp() || (this.constructor as any).help; + if (help) { + this.helpDialog = new RgthreeHelpDialog(this, help).show(); + this.helpDialog.addEventListener("close", (e) => { + this.helpDialog = null; + }); + } + } + + override onKeyDown(event: KeyboardEvent): void { + KEY_EVENT_SERVICE.handleKeyDownOrUp(event); + if (event.key == "?" && !this.helpDialog) { + this.showHelp(); + } + } + + override onKeyUp(event: KeyboardEvent): void { + KEY_EVENT_SERVICE.handleKeyDownOrUp(event); + } + + override getExtraMenuOptions(canvas: LGraphCanvas, options: ContextMenuItem[]): void { + // Some other extensions override getExtraMenuOptions on the nodeType as it comes through from + // the server, so we can call out to that if we don't have our own. + if (super.getExtraMenuOptions) { + super.getExtraMenuOptions?.apply(this, [canvas, options]); + } else if ((this.constructor as any).nodeType?.prototype?.getExtraMenuOptions) { + (this.constructor as any).nodeType?.prototype?.getExtraMenuOptions?.apply(this, [ + canvas, + options, + ]); + } + // If we have help content, then add a menu item. + const help = this.getHelp() || (this.constructor as any).help; + if (help) { + addHelpMenuItem(this, help, options); + } + } +} + +/** + * A virtual node. Right now, this is just a wrapper for RgthreeBaseNode (which was the initial + * base virtual node). + * + * TODO: Make RgthreeBaseNode private and move all virtual nodes to this class; cleanup + * RgthreeBaseNode assumptions that its virtual. 
+ */ +export class RgthreeBaseVirtualNode extends RgthreeBaseNode { + override isVirtualNode = true; + + constructor(title = RgthreeBaseNode.title) { + super(title, false); + } + + static override setUp() { + if (!this.type) { + throw new Error(`Missing type for RgthreeBaseVirtualNode: ${this.title}`); + } + LiteGraph.registerNodeType(this.type, this); + if (this._category) { + this.category = this._category; + } + } +} + +/** + * A base node with standard methods, extending the LGraphNode. + * This is somewhat experimental, but if comfyui is going to keep breaking widgets and inputs, it + * seems safer than NOT overriding. + */ +export class RgthreeBaseServerNode extends RgthreeBaseNode { + static nodeData: ComfyObjectInfo | null = null; + static nodeType: ComfyNodeConstructor | null = null; + + // Drop is enabled by default for server nodes. + override isDropEnabled = true; + + constructor(title: string) { + super(title, true); + this.serialize_widgets = true; + this.setupFromServerNodeData(); + this.onConstructed(); + } + + getWidgets() { + return ComfyWidgets; + } + + /** + * This takes the server data and builds out the inputs, outputs and widgets. It's similar to the + * ComfyNode constructor in registerNodes in ComfyUI's app.js, but is more stable and thus + * shouldn't break as often when it modifyies widgets and types. + */ + async setupFromServerNodeData() { + const nodeData = (this.constructor as any).nodeData; + if (!nodeData) { + throw Error("No node data"); + } + + // Necessary for serialization so Comfy backend can check types. + // Serialized as `class_type`. See app.js#graphToPrompt + this.comfyClass = nodeData.name; + + let inputs = nodeData["input"]["required"]; + if (nodeData["input"]["optional"] != undefined) { + inputs = Object.assign({}, inputs, nodeData["input"]["optional"]); + } + + const WIDGETS = this.getWidgets(); + + const config: { minWidth: number; minHeight: number; widget?: null | { options: any } } = { + minWidth: 1, + minHeight: 1, + widget: null, + }; + for (const inputName in inputs) { + const inputData = inputs[inputName]; + const type = inputData[0]; + // If we're forcing the input, just do it now and forget all that widget stuff. + // This is one of the differences from ComfyNode and provides smoother experience for inputs + // that are going to remain inputs anyway. + // Also, it fixes https://github.com/comfyanonymous/ComfyUI/issues/1404 (for rgthree nodes) + if (inputData[1]?.forceInput) { + this.addInput(inputName, type); + } else { + let widgetCreated = true; + if (Array.isArray(type)) { + // Enums + Object.assign(config, WIDGETS.COMBO(this, inputName, inputData, app) || {}); + } else if (`${type}:${inputName}` in WIDGETS) { + // Support custom widgets by Type:Name + Object.assign( + config, + WIDGETS[`${type}:${inputName}`]!(this, inputName, inputData, app) || {}, + ); + } else if (type in WIDGETS) { + // Standard type widgets + Object.assign(config, WIDGETS[type]!(this, inputName, inputData, app) || {}); + } else { + // Node connection inputs + this.addInput(inputName, type); + widgetCreated = false; + } + + // Don't actually need this right now, but ported it over from ComfyWidget. 
+ if (widgetCreated && inputData[1]?.forceInput && config?.widget) { + if (!config.widget.options) config.widget.options = {}; + config.widget.options.forceInput = inputData[1].forceInput; + } + if (widgetCreated && inputData[1]?.defaultInput && config?.widget) { + if (!config.widget.options) config.widget.options = {}; + config.widget.options.defaultInput = inputData[1].defaultInput; + } + } + } + + for (const o in nodeData["output"]) { + let output = nodeData["output"][o]; + if (output instanceof Array) output = "COMBO"; + const outputName = nodeData["output_name"][o] || output; + const outputShape = nodeData["output_is_list"][o] + ? LiteGraph.GRID_SHAPE + : LiteGraph.CIRCLE_SHAPE; + this.addOutput(outputName, output, { shape: outputShape }); + } + + const s = this.computeSize(); + s[0] = Math.max(config.minWidth, s[0] * 1.5); + s[1] = Math.max(config.minHeight, s[1]); + this.size = s; + this.serialize_widgets = true; + } + + static __registeredForOverride__: boolean = false; + static registerForOverride( + comfyClass: ComfyNodeConstructor, + nodeData: ComfyObjectInfo, + rgthreeClass: RgthreeBaseServerNodeConstructor, + ) { + if (OVERRIDDEN_SERVER_NODES.has(comfyClass)) { + throw Error( + `Already have a class to override ${ + comfyClass.type || comfyClass.name || comfyClass.title + }`, + ); + } + OVERRIDDEN_SERVER_NODES.set(comfyClass, rgthreeClass); + // Mark the rgthreeClass as `__registeredForOverride__` because ComfyUI will repeatedly call + // this and certain setups will only want to setup once (like adding context menus, etc). + if (!rgthreeClass.__registeredForOverride__) { + rgthreeClass.__registeredForOverride__ = true; + rgthreeClass.nodeType = comfyClass; + rgthreeClass.nodeData = nodeData; + rgthreeClass.onRegisteredForOverride(comfyClass, rgthreeClass); + } + } + + static onRegisteredForOverride(comfyClass: any, rgthreeClass: any) { + // To be overridden + } +} + +/** + * Keeps track of the rgthree-comfy nodes that come from the server (and want to be ComfyNodes) that + * we override into a own, more flexible and cleaner nodes. + */ +const OVERRIDDEN_SERVER_NODES = new Map(); + +const oldregisterNodeType = LiteGraph.registerNodeType; +/** + * ComfyUI calls registerNodeType with its ComfyNode, but we don't trust that will remain stable, so + * we need to identify it, intercept it, and supply our own class for the node. + */ +LiteGraph.registerNodeType = async function (nodeId: string, baseClass: any) { + const clazz = OVERRIDDEN_SERVER_NODES.get(baseClass) || baseClass; + if (clazz !== baseClass) { + const classLabel = clazz.type || clazz.name || clazz.title; + const [n, v] = rgthree.logger.logParts( + LogLevel.DEBUG, + `${nodeId}: replacing default ComfyNode implementation with custom ${classLabel} class.`, + ); + console[n]?.(...v); + // Note, we don't currently call our rgthree.invokeExtensionsAsync w/ beforeRegisterNodeDef as + // this runs right after that. However, this does mean that extensions cannot actually change + // anything about overriden server rgthree nodes in their beforeRegisterNodeDef (as when comfy + // calls it, it's for the wrong ComfyNode class). Calling it here, however, would re-run + // everything causing more issues than not. If we wanted to support beforeRegisterNodeDef then + // it would mean rewriting ComfyUI's registerNodeDef which, frankly, is not worth it. 
+ } + return oldregisterNodeType.call(LiteGraph, nodeId, clazz); +}; diff --git a/rgthree-comfy/src_web/comfyui/base_node_collector.ts b/rgthree-comfy/src_web/comfyui/base_node_collector.ts new file mode 100644 index 0000000000000000000000000000000000000000..509c955a4e647e4dac31744ead1ff877fbdae962 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/base_node_collector.ts @@ -0,0 +1,98 @@ +import type { INodeOutputSlot, LGraphNode } from "typings/litegraph.js"; +import { rgthree } from "./rgthree.js"; +import { BaseAnyInputConnectedNode } from "./base_any_input_connected_node.js"; +import { + PassThroughFollowing, + getConnectedInputNodes, + getConnectedInputNodesAndFilterPassThroughs, + shouldPassThrough, +} from "./utils.js"; + +/** + * Base collector node that monitors changing inputs and outputs. + */ +export class BaseCollectorNode extends BaseAnyInputConnectedNode { + /** + * We only want to show nodes through re_route nodes, other pass through nodes show each input. + */ + override readonly inputsPassThroughFollowing: PassThroughFollowing = + PassThroughFollowing.REROUTE_ONLY; + + readonly logger = rgthree.newLogSession("[BaseCollectorNode]"); + + constructor(title?: string) { + super(title); + } + + override clone() { + const cloned = super.clone(); + return cloned; + } + + override handleLinkedNodesStabilization(linkedNodes: LGraphNode[]): void { + // No-op, no widgets. + } + + /** + * When we connect an input, check to see if it's already connected and cancel it. + */ + override onConnectInput( + inputIndex: number, + outputType: string | -1, + outputSlot: INodeOutputSlot, + outputNode: LGraphNode, + outputIndex: number, + ): boolean { + let canConnect = super.onConnectInput( + inputIndex, + outputType, + outputSlot, + outputNode, + outputIndex, + ); + if (canConnect) { + const allConnectedNodes = getConnectedInputNodes(this); // We want passthrough nodes, since they will loop. + const nodesAlreadyInSlot = getConnectedInputNodes(this, undefined, inputIndex); + if (allConnectedNodes.includes(outputNode)) { + // If we're connecting to the same slot, then allow it by replacing the one we have. + // const slotsOriginNode = getOriginNodeByLink(this.inputs[inputIndex]?.link); + const [n, v] = this.logger.debugParts( + `${outputNode.title} is already connected to ${this.title}.`, + ); + console[n]?.(...v); + if (nodesAlreadyInSlot.includes(outputNode)) { + const [n, v] = this.logger.debugParts( + `... but letting it slide since it's for the same slot.`, + ); + console[n]?.(...v); + } else { + canConnect = false; + } + } + if (canConnect && shouldPassThrough(outputNode, PassThroughFollowing.REROUTE_ONLY)) { + const connectedNode = getConnectedInputNodesAndFilterPassThroughs( + outputNode, + undefined, + undefined, + PassThroughFollowing.REROUTE_ONLY, + )[0]; + if (connectedNode && allConnectedNodes.includes(connectedNode)) { + // If we're connecting to the same slot, then allow it by replacing the one we have. + const [n, v] = this.logger.debugParts( + `${connectedNode.title} is already connected to ${this.title}.`, + ); + console[n]?.(...v); + if (nodesAlreadyInSlot.includes(connectedNode)) { + const [n, v] = this.logger.debugParts( + `... 
but letting it slide since it's for the same slot.`, + ); + console[n]?.(...v); + } else { + canConnect = false; + } + } + } + } + return canConnect; + } +} diff --git a/rgthree-comfy/src_web/comfyui/base_node_mode_changer.ts b/rgthree-comfy/src_web/comfyui/base_node_mode_changer.ts new file mode 100644 index 0000000000000000000000000000000000000000..4de801ae373f9f4980eb964c6eb3fff06befdfad --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/base_node_mode_changer.ts @@ -0,0 +1,103 @@ +import type { RgthreeBaseVirtualNodeConstructor } from "typings/rgthree.js"; +import type { + LGraphNode as TLGraphNode, + IWidget, + SerializedLGraphNode, +} from "typings/litegraph.js"; +import { BaseAnyInputConnectedNode } from "./base_any_input_connected_node.js"; +import { PassThroughFollowing } from "./utils.js"; +import { wait } from "rgthree/common/shared_utils.js"; + +export class BaseNodeModeChanger extends BaseAnyInputConnectedNode { + override readonly inputsPassThroughFollowing: PassThroughFollowing = PassThroughFollowing.ALL; + + static collapsible = false; + override isVirtualNode = true; + + // These Must be overriden + readonly modeOn: number = -1; + readonly modeOff: number = -1; + + static "@toggleRestriction" = { + type: "combo", + values: ["default", "max one", "always one"], + }; + + constructor(title?: string) { + super(title); + this.properties["toggleRestriction"] = "default"; + } + + override onConstructed(): boolean { + wait(10).then(() => { + if (this.modeOn < 0 || this.modeOff < 0) { + throw new Error("modeOn and modeOff must be overridden."); + } + }); + this.addOutput("OPT_CONNECTION", "*"); + return super.onConstructed(); + } + + override configure(info: SerializedLGraphNode): void { + // Patch a small issue (~14h) where multiple OPT_CONNECTIONS may have been created. + // https://github.com/rgthree/rgthree-comfy/issues/206 + // TODO: This can probably be removed within a few weeks. + if (info.outputs?.length) { + info.outputs.length = 1; + } + super.configure(info); + } + + override handleLinkedNodesStabilization(linkedNodes: TLGraphNode[]) { + for (const [index, node] of linkedNodes.entries()) { + let widget = this.widgets && this.widgets[index]; + if (!widget) { + // When we add a widget, litegraph is going to mess up the size, so we + // store it so we can retrieve it in computeSize. Hacky.. + (this as any)._tempWidth = this.size[0]; + widget = this.addWidget("toggle", "", false, "", { on: "yes", off: "no" }); + } + node && this.setWidget(widget, node); + } + if (this.widgets && this.widgets.length > linkedNodes.length) { + this.widgets.length = linkedNodes.length; + } + } + + protected setWidget(widget: IWidget, linkedNode: TLGraphNode, forceValue?: boolean) { + const value = forceValue == null ? linkedNode.mode === this.modeOn : forceValue; + widget.name = `Enable ${linkedNode.title}`; + widget.options = { on: "yes", off: "no" }; + widget.value = value; + (widget as any).doModeChange = (forceValue?: boolean, skipOtherNodeCheck?: boolean) => { + let newValue = forceValue == null ? linkedNode.mode === this.modeOff : forceValue; + if (skipOtherNodeCheck !== true) { + if (newValue && this.properties?.["toggleRestriction"]?.includes(" one")) { + for (const widget of this.widgets) { + (widget as any).doModeChange(false, true); + } + } else if (!newValue && this.properties?.["toggleRestriction"] === "always one") { + newValue = this.widgets.every((w) => !w.value || w === widget); + } + } + linkedNode.mode = (newValue ? 
this.modeOn : this.modeOff) as 1 | 2 | 3 | 4; + widget.value = newValue; + }; + widget.callback = () => { + (widget as any).doModeChange(); + }; + if (forceValue != null) { + linkedNode.mode = (forceValue ? this.modeOn : this.modeOff) as 1 | 2 | 3 | 4; + } + } + + forceWidgetOff(widget: IWidget, skipOtherNodeCheck?: boolean) { + (widget as any).doModeChange(false, skipOtherNodeCheck); + } + forceWidgetOn(widget: IWidget, skipOtherNodeCheck?: boolean) { + (widget as any).doModeChange(true, skipOtherNodeCheck); + } + forceWidgetToggle(widget: IWidget, skipOtherNodeCheck?: boolean) { + (widget as any).doModeChange(!widget.value, skipOtherNodeCheck); + } +} diff --git a/rgthree-comfy/src_web/comfyui/base_power_prompt.ts b/rgthree-comfy/src_web/comfyui/base_power_prompt.ts new file mode 100644 index 0000000000000000000000000000000000000000..e0e262e1feb2a6ccff428cee1071314954c904aa --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/base_power_prompt.ts @@ -0,0 +1,364 @@ +import { api } from "scripts/api.js"; +import type { + LLink, + IComboWidget, + LGraphNode, + INodeOutputSlot, + INodeInputSlot, + IWidget, + SerializedLGraphNode, +} from "typings/litegraph.js"; +import type { ComfyObjectInfo, ComfyGraphNode } from "typings/comfy.js"; +import { wait } from "rgthree/common/shared_utils.js"; +import { rgthree } from "./rgthree.js"; + +/** Wraps a node instance keeping closure without mucking the finicky types. */ +export class PowerPrompt { + readonly isSimple: boolean; + readonly node: ComfyGraphNode; + readonly promptEl: HTMLTextAreaElement; + nodeData: ComfyObjectInfo; + readonly combos: { [key: string]: IComboWidget } = {}; + readonly combosValues: { [key: string]: string[] } = {}; + boundOnFreshNodeDefs!: (event: CustomEvent) => void; + + private configuring = false; + + constructor(node: ComfyGraphNode, nodeData: ComfyObjectInfo) { + this.node = node; + this.node.properties = this.node.properties || {}; + + this.node.properties["combos_filter"] = ""; + + this.nodeData = nodeData; + this.isSimple = this.nodeData.name.includes("Simple"); + + this.promptEl = (node.widgets[0]! 
as any).inputEl; + this.addAndHandleKeyboardLoraEditWeight(); + + this.patchNodeRefresh(); + + const oldConfigure = this.node.configure; + this.node.configure = (info: SerializedLGraphNode) => { + this.configuring = true; + oldConfigure?.apply(this.node, [info]); + this.configuring = false; + }; + + const oldOnConnectionsChange = this.node.onConnectionsChange; + this.node.onConnectionsChange = ( + type: number, + slotIndex: number, + isConnected: boolean, + link_info: LLink, + _ioSlot: INodeOutputSlot | INodeInputSlot, + ) => { + oldOnConnectionsChange?.apply(this.node, [type, slotIndex, isConnected, link_info, _ioSlot]); + this.onNodeConnectionsChange(type, slotIndex, isConnected, link_info, _ioSlot); + }; + + const oldOnConnectInput = this.node.onConnectInput; + this.node.onConnectInput = ( + inputIndex: number, + outputType: INodeOutputSlot["type"], + outputSlot: INodeOutputSlot, + outputNode: LGraphNode, + outputIndex: number, + ) => { + let canConnect = true; + if (oldOnConnectInput) { + canConnect = oldOnConnectInput.apply(this.node, [ + inputIndex, + outputType, + outputSlot, + outputNode, + outputIndex, + ]); + } + return ( + this.configuring || + rgthree.loadingApiJson || + (canConnect && !this.node.inputs[inputIndex]!.disabled) + ); + }; + + const oldOnConnectOutput = this.node.onConnectOutput; + this.node.onConnectOutput = ( + outputIndex: number, + inputType: INodeInputSlot["type"], + inputSlot: INodeInputSlot, + inputNode: LGraphNode, + inputIndex: number, + ) => { + let canConnect = true; + if (oldOnConnectOutput) { + canConnect = oldOnConnectOutput?.apply(this.node, [ + outputIndex, + inputType, + inputSlot, + inputNode, + inputIndex, + ]); + } + return ( + this.configuring || + rgthree.loadingApiJson || + (canConnect && !this.node.outputs[outputIndex]!.disabled) + ); + }; + + const onPropertyChanged = this.node.onPropertyChanged; + this.node.onPropertyChanged = (property: string, value: any, prevValue: any) => { + onPropertyChanged && onPropertyChanged.call(this, property, value, prevValue); + if (property === "combos_filter") { + this.refreshCombos(this.nodeData); + } + }; + + // Strip all widgets but prompt (we'll re-add them in refreshCombos) + // this.node.widgets.splice(1); + for (let i = this.node.widgets.length - 1; i >= 0; i--) { + if (this.shouldRemoveServerWidget(this.node.widgets[i]!)) { + this.node.widgets.splice(i, 1); + } + } + + this.refreshCombos(nodeData); + setTimeout(() => { + this.stabilizeInputsOutputs(); + }, 32); + } + + /** + * Cleans up optional out puts when we don't have the optional input. Purely a vanity function. + */ + onNodeConnectionsChange( + _type: number, + _slotIndex: number, + _isConnected: boolean, + _linkInfo: LLink, + _ioSlot: INodeOutputSlot | INodeInputSlot, + ) { + this.stabilizeInputsOutputs(); + } + + private stabilizeInputsOutputs() { + // If we are currently "configuring" then skip this stabilization. The connected nodes may + // not yet be configured. + if (this.configuring || rgthree.loadingApiJson) { + return; + } + // If our first input is connected, then we can show the proper output. 
+ const clipLinked = this.node.inputs.some((i) => i.name.includes("clip") && !!i.link); + const modelLinked = this.node.inputs.some((i) => i.name.includes("model") && !!i.link); + for (const output of this.node.outputs) { + const type = (output.type as string).toLowerCase(); + if (type.includes("model")) { + output.disabled = !modelLinked; + } else if (type.includes("conditioning")) { + output.disabled = !clipLinked; + } else if (type.includes("clip")) { + output.disabled = !clipLinked; + } else if (type.includes("string")) { + // Our text prompt is always enabled, but let's color it so it stands out + // if the others are disabled. #7F7 is Litegraph's default. + output.color_off = "#7F7"; + output.color_on = "#7F7"; + } + if (output.disabled) { + // this.node.disconnectOutput(index); + } + } + } + + onFreshNodeDefs(event: CustomEvent) { + this.refreshCombos(event.detail[this.nodeData.name]); + } + + shouldRemoveServerWidget(widget: IWidget) { + return ( + widget.name?.startsWith("insert_") || + widget.name?.startsWith("target_") || + widget.name?.startsWith("crop_") || + widget.name?.startsWith("values_") + ); + } + + refreshCombos(nodeData: ComfyObjectInfo) { + this.nodeData = nodeData; + let filter: RegExp | null = null; + if (this.node.properties["combos_filter"]?.trim()) { + try { + filter = new RegExp(this.node.properties["combos_filter"].trim(), "i"); + } catch (e) { + console.error(`Could not parse "${filter}" for Regular Expression`, e); + filter = null; + } + } + + // Add the combo for hidden inputs of nodeData + let data = Object.assign( + {}, + this.nodeData.input?.optional || {}, + this.nodeData.input?.hidden || {}, + ); + + for (const [key, value] of Object.entries(data)) { + //Object.entries(this.nodeData.input?.hidden || {})) { + if (Array.isArray(value[0])) { + let values = value[0] as string[]; + if (key.startsWith("insert")) { + values = filter + ? values.filter( + (v, i) => i < 1 || (i == 1 && v.match(/^disable\s[a-z]/i)) || filter?.test(v), + ) + : values; + const shouldShow = + values.length > 2 || (values.length > 1 && !values[1]!.match(/^disable\s[a-z]/i)); + if (shouldShow) { + if (!this.combos[key]) { + this.combos[key] = this.node.addWidget( + "combo", + key, + values, + (selected) => { + if (selected !== values[0] && !selected.match(/^disable\s[a-z]/i)) { + // We wait a frame because if we use a keydown event to call, it'll wipe out + // the selection. + wait().then(() => { + if (key.includes("embedding")) { + this.insertSelectionText(`embedding:${selected}`); + } else if (key.includes("saved")) { + this.insertSelectionText( + this.combosValues[`values_${key}`]![values.indexOf(selected)]!, + ); + } else if (key.includes("lora")) { + this.insertSelectionText(``); + } + this.combos[key]!.value = values[0]; + }); + } + }, + { + values, + serialize: true, // Don't include this in prompt. + }, + ); + (this.combos[key]! 
as any).oldComputeSize = this.combos[key]!.computeSize; + let node = this.node; + this.combos[key]!.computeSize = function (width: number) { + const size = (this as any).oldComputeSize?.(width) || [ + width, + LiteGraph.NODE_WIDGET_HEIGHT, + ]; + if (this === node.widgets[node.widgets.length - 1]) { + size[1] += 10; + } + return size; + }; + } + this.combos[key]!.options!.values = values; + this.combos[key]!.value = values[0]; + } else if (!shouldShow && this.combos[key]) { + this.node.widgets.splice(this.node.widgets.indexOf(this.combos[key]!), 1); + delete this.combos[key]; + } + } else if (key.startsWith("values")) { + this.combosValues[key] = values; + } + } + } + } + + insertSelectionText(text: string) { + if (!this.promptEl) { + console.error("Asked to insert text, but no textbox found."); + return; + } + let prompt = this.promptEl.value; + // Use selectionEnd as the split; if we have highlighted text, then we likely don't want to + // overwrite it (we could have just deleted it more easily). + let first = prompt.substring(0, this.promptEl.selectionEnd).replace(/ +$/, ""); + first = first + (["\n"].includes(first[first.length - 1]!) ? "" : first.length ? " " : ""); + let second = prompt.substring(this.promptEl.selectionEnd).replace(/^ +/, ""); + second = (["\n"].includes(second[0]!) ? "" : second.length ? " " : "") + second; + this.promptEl.value = first + text + second; + this.promptEl.focus(); + this.promptEl.selectionStart = first.length; + this.promptEl.selectionEnd = first.length + text.length; + } + + /** + * Adds a keydown event listener to our prompt so we can see if we're using the + * ctrl/cmd + up/down arrows shortcut. This kind of competes with the core extension + * "Comfy.EditAttention" but since that only handles parenthesis and listens on window, we should + * be able to intercept and cancel the bubble if we're doing the same action within the lora tag. + */ + addAndHandleKeyboardLoraEditWeight() { + this.promptEl.addEventListener("keydown", (event: KeyboardEvent) => { + // If we're not doing a ctrl/cmd + arrow key, then bail. + if (!(event.key === "ArrowUp" || event.key === "ArrowDown")) return; + if (!event.ctrlKey && !event.metaKey) return; + // Unfortunately, we can't see Comfy.EditAttention delta in settings, so we hardcode to 0.01. + // We can acutally do better too, let's make it .1 by default, and .01 if also holding shift. + const delta = event.shiftKey ? 0.01 : 0.1; + + let start = this.promptEl.selectionStart; + let end = this.promptEl.selectionEnd; + let fullText = this.promptEl.value; + let selectedText = fullText.substring(start, end); + + // We don't care about fully rewriting Comfy.EditAttention, we just want to see if our + // selected text is a lora, which will always start with "") { + start -= 2; + end -= 2; + } + if (fullText[end - 1] == "<") { + start += 2; + end += 2; + } + while (!stopOn.includes(fullText[start]!) && start > 0) { + start--; + } + while (!stopOn.includes(fullText[end - 1]!) && end < fullText.length) { + end++; + } + selectedText = fullText.substring(start, end); + } + + // Bail if this isn't a lora. + if (!selectedText.startsWith("")) { + return; + } + + let weight = Number(selectedText.match(/:(-?\d*(\.\d*)?)>$/)?.[1]) ?? 1; + weight += event.key === "ArrowUp" ? delta : -delta; + const updatedText = selectedText.replace(/(:-?\d*(\.\d*)?)?>$/, `:${weight.toFixed(2)}>`); + + // Handle the new value and cancel the bubble so Comfy.EditAttention doesn't also try. 
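      // Worked example (assuming a tag that ends in the ":weight>" form this regex expects):
      // with ":0.80>" at the end of the selection, ctrl/cmd+ArrowUp yields 0.8 + 0.1 -> ":0.90>",
      // and holding shift as well steps by 0.01 instead -> ":0.81>"; toFixed(2) keeps two decimals.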
+ this.promptEl.setRangeText(updatedText, start, end, "select"); + event.preventDefault(); + event.stopPropagation(); + }); + } + + /** + * Patches over api.getNodeDefs in comfy's api.js to fire a custom event that we can listen to + * here and manually refresh our combos when a request comes in to fetch the node data; which + * only happens once at startup (but before custom nodes js runs), and then after clicking + * the "Refresh" button in the floating menu, which is what we care about. + */ + patchNodeRefresh() { + this.boundOnFreshNodeDefs = this.onFreshNodeDefs.bind(this); + api.addEventListener("fresh-node-defs", this.boundOnFreshNodeDefs as EventListener); + const oldNodeRemoved = this.node.onRemoved; + this.node.onRemoved = () => { + oldNodeRemoved?.call(this.node); + api.removeEventListener("fresh-node-defs", this.boundOnFreshNodeDefs as EventListener); + }; + } +} diff --git a/rgthree-comfy/src_web/comfyui/bookmark.ts b/rgthree-comfy/src_web/comfyui/bookmark.ts new file mode 100644 index 0000000000000000000000000000000000000000..5e8c6a41878babaa4b98e2f339f17eea26cc9bd7 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/bookmark.ts @@ -0,0 +1,163 @@ +import { app } from "scripts/app.js"; +import { RgthreeBaseVirtualNode } from "./base_node.js"; +import { SERVICE as KEY_EVENT_SERVICE } from "./services/key_events_services.js"; +import { NodeTypesString } from "./constants.js"; +import type { + LGraph, + LGraphCanvas, + INumberWidget, + LGraphNode, + Vector2, +} from "typings/litegraph.js"; +import { getClosestOrSelf, queryOne } from "rgthree/common/utils_dom.js"; + +/** + * A bookmark node. Can be placed anywhere in the workflow, and given a shortcut key that will + * navigate to that node, with it in the top-left corner. + */ +export class Bookmark extends RgthreeBaseVirtualNode { + static override type = NodeTypesString.BOOKMARK; + static override title = NodeTypesString.BOOKMARK; + override comfyClass = NodeTypesString.BOOKMARK; + + // Really silly, but Litegraph assumes we have at least one input/output... so we need to + // counteract it's computeSize calculation by offsetting the start. + static slot_start_y = -20; + + // LiteGraph adds mroe spacing than we want when calculating a nodes' `_collapsed_width`, so we'll + // override it with a setter and re-set it measured exactly as we want. + ___collapsed_width: number = 0; + + override isVirtualNode = true; + override serialize_widgets = true; + + //@ts-ignore - TS Doesn't like us overriding a property with accessors but, too bad. 
+ override get _collapsed_width() { + return this.___collapsed_width; + } + + override set _collapsed_width(width: number) { + const canvas = app.canvas as LGraphCanvas; + const ctx = canvas.canvas.getContext("2d")!; + const oldFont = ctx.font; + ctx.font = canvas.title_text_font; + this.___collapsed_width = 40 + ctx.measureText(this.title).width; + ctx.font = oldFont; + } + + readonly keypressBound; + + constructor(title = Bookmark.title) { + super(title); + const nextShortcutChar = getNextShortcut(); + this.addWidget( + "text", + "shortcut_key", + nextShortcutChar, + (value: string, ...args) => { + value = value.trim()[0] || "1"; + }, + { + y: 8, + }, + ); + this.addWidget("number", "zoom", 1, (value: number) => {}, { + y: 8 + LiteGraph.NODE_WIDGET_HEIGHT + 4, + max: 2, + min: 0.5, + precision: 2, + }); + this.keypressBound = this.onKeypress.bind(this); + this.title = "🔖"; + this.onConstructed(); + } + + // override computeSize(out?: Vector2 | undefined): Vector2 { + // super.computeSize(out); + // const minHeight = (this.widgets?.length || 0) * (LiteGraph.NODE_WIDGET_HEIGHT + 4) + 16; + // this.size[1] = Math.max(minHeight, this.size[1]); + // } + + get shortcutKey(): string { + return this.widgets[0]?.value?.toLocaleLowerCase() ?? ""; + } + + override onAdded(graph: LGraph): void { + KEY_EVENT_SERVICE.addEventListener("keydown", this.keypressBound as EventListener); + } + + override onRemoved(): void { + KEY_EVENT_SERVICE.removeEventListener("keydown", this.keypressBound as EventListener); + } + + onKeypress(event: CustomEvent<{ originalEvent: KeyboardEvent }>) { + const originalEvent = event.detail.originalEvent; + const target = (originalEvent.target as HTMLElement)!; + if (getClosestOrSelf(target, 'input,textarea,[contenteditable="true"]')) { + return; + } + + // Only the shortcut keys are held down, otionally including "shift". + if (KEY_EVENT_SERVICE.areOnlyKeysDown(this.widgets[0]!.value, true)) { + this.canvasToBookmark(); + originalEvent.preventDefault(); + originalEvent.stopPropagation(); + } + } + + /** + * Called from LiteGraph's `processMouseDown` after it would invoke the input box for the + * shortcut_key, so we check if it exists and then add our own event listener so we can track the + * keys down for the user. + */ + override onMouseDown(event: MouseEvent, pos: Vector2, graphCanvas: LGraphCanvas): void { + const input = queryOne(".graphdialog > input.value"); + if (input && input.value === this.widgets[0]?.value) { + input.addEventListener("keydown", (e) => { + // ComfyUI swallows keydown on inputs, so we need to call out to rgthree to use downkeys. + KEY_EVENT_SERVICE.handleKeyDownOrUp(e); + e.preventDefault(); + e.stopPropagation(); + input.value = Object.keys(KEY_EVENT_SERVICE.downKeys).join(" + "); + }); + } + } + + canvasToBookmark() { + const canvas = app.canvas as LGraphCanvas; + // ComfyUI seemed to break us again, but couldn't repro. No reason to not check, I guess. 
+ // https://github.com/rgthree/rgthree-comfy/issues/71 + if (canvas?.ds?.offset) { + canvas.ds.offset[0] = -this.pos[0] + 16; + canvas.ds.offset[1] = -this.pos[1] + 40; + } + if (canvas?.ds?.scale != null) { + canvas.ds.scale = Number(this.widgets[1]!.value || 1); + } + canvas.setDirty(true, true); + } +} + +app.registerExtension({ + name: "rgthree.Bookmark", + registerCustomNodes() { + Bookmark.setUp(); + }, +}); + +function isBookmark(node: LGraphNode): node is Bookmark { + return node.type === NodeTypesString.BOOKMARK; +} + +function getExistingShortcuts() { + const graph: LGraph = app.graph; + const bookmarkNodes = graph._nodes.filter(isBookmark); + const usedShortcuts = new Set(bookmarkNodes.map((n) => n.shortcutKey)); + return usedShortcuts; +} + +const SHORTCUT_DEFAULTS = "1234567890abcdefghijklmnopqrstuvwxyz".split(""); +function getNextShortcut() { + const existingShortcuts = getExistingShortcuts(); + return SHORTCUT_DEFAULTS.find((char) => !existingShortcuts.has(char)) ?? "1"; +} diff --git a/rgthree-comfy/src_web/comfyui/bypasser.ts b/rgthree-comfy/src_web/comfyui/bypasser.ts new file mode 100644 index 0000000000000000000000000000000000000000..c75125a40919ec43fc9c00dab6e68bbdcfd018ef --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/bypasser.ts @@ -0,0 +1,51 @@ +import { app } from "scripts/app.js"; +import { BaseNodeModeChanger } from "./base_node_mode_changer.js"; +import { NodeTypesString } from "./constants.js"; +import type { LGraphNode } from "typings/litegraph.js"; + +const MODE_BYPASS = 4; +const MODE_ALWAYS = 0; + +class BypasserNode extends BaseNodeModeChanger { + static override exposedActions = ["Bypass all", "Enable all", "Toggle all"]; + + static override type = NodeTypesString.FAST_BYPASSER; + static override title = NodeTypesString.FAST_BYPASSER; + override comfyClass = NodeTypesString.FAST_BYPASSER; + + override readonly modeOn = MODE_ALWAYS; + override readonly modeOff = MODE_BYPASS; + + constructor(title = BypasserNode.title) { + super(title); + this.onConstructed(); + } + + override async handleAction(action: string) { + if (action === "Bypass all") { + for (const widget of this.widgets) { + this.forceWidgetOff(widget, true); + } + } else if (action === "Enable all") { + for (const widget of this.widgets) { + this.forceWidgetOn(widget, true); + } + } else if (action === "Toggle all") { + for (const widget of this.widgets) { + this.forceWidgetToggle(widget, true); + } + } + } +} + +app.registerExtension({ + name: "rgthree.Bypasser", + registerCustomNodes() { + BypasserNode.setUp(); + }, + loadedGraphNode(node: LGraphNode) { + if (node.type == BypasserNode.title) { + (node as any)._tempWidth = node.size[0]; + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/comfy_ui_bar.ts b/rgthree-comfy/src_web/comfyui/comfy_ui_bar.ts new file mode 100644 index 0000000000000000000000000000000000000000..bca36d975bb804c8483865664e0c9b56c4d32cce --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/comfy_ui_bar.ts @@ -0,0 +1,120 @@ +import { app } from "scripts/app.js"; +import { ComfyButtonGroup } from "scripts/ui/components/buttonGroup.js"; +import { ComfyButton } from "scripts/ui/components/button.js"; +import { iconGear, iconStarFilled, logoRgthree } from "rgthree/common/media/svgs.js"; +import { createElement, empty, queryOne } from "rgthree/common/utils_dom.js"; +import { SERVICE as BOOKMARKS_SERVICE } from "./services/bookmarks_services.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +import { ComfyPopup } from 
"scripts/ui/components/popup.js"; +import { RgthreeConfigDialog } from "./config.js"; + +let rgthreeButtonGroup: ComfyButtonGroup | null = null; + +function addRgthreeTopBarButtons() { + if (!CONFIG_SERVICE.getFeatureValue("comfy_top_bar_menu.enabled")) { + if (rgthreeButtonGroup?.element?.parentElement) { + rgthreeButtonGroup.element.parentElement.removeChild(rgthreeButtonGroup.element); + } + return; + } else if (rgthreeButtonGroup) { + app.menu?.settingsGroup.element.before(rgthreeButtonGroup.element); + return; + } + + const buttons = []; + + const rgthreeButton = new ComfyButton({ + icon: "rgthree", + tooltip: "rgthree-comfy", + // content: 'rgthree-comfy', + app, + enabled: true, + classList: "comfyui-button comfyui-menu-mobile-collapse primary", + }); + buttons.push(rgthreeButton); + rgthreeButton.iconElement.style.width = "1.2rem"; + rgthreeButton.iconElement.innerHTML = logoRgthree; + rgthreeButton.withPopup( + new ComfyPopup( + { target: rgthreeButton.element, classList: "rgthree-top-menu" }, + createElement("menu", { + children: [ + createElement("li", { + child: createElement("button.rgthree-button-reset", { + html: iconGear + "Settings (rgthree-comfy)", + onclick: () => new RgthreeConfigDialog().show(), + }), + }), + createElement("li", { + child: createElement("button.rgthree-button-reset", { + html: iconStarFilled + "Star on Github", + onclick: () => window.open("https://github.com/rgthree/rgthree-comfy", "_blank"), + }), + }), + ], + }), + ), + "click", + ); + + if (CONFIG_SERVICE.getFeatureValue("comfy_top_bar_menu.button_bookmarks.enabled")) { + const bookmarksListEl = createElement("menu"); + bookmarksListEl.appendChild( + createElement("li.rgthree-message", { + child: createElement("span", { text: "No bookmarks in current workflow." }), + }), + ); + const bookmarksButton = new ComfyButton({ + icon: "bookmark", + tooltip: "Workflow Bookmarks (rgthree-comfy)", + app, + }); + const bookmarksPopup = new ComfyPopup( + { target: bookmarksButton.element, classList: "rgthree-top-menu" }, + bookmarksListEl, + ); + bookmarksPopup.addEventListener("open", () => { + const bookmarks = BOOKMARKS_SERVICE.getCurrentBookmarks(); + empty(bookmarksListEl); + if (bookmarks.length) { + for (const b of bookmarks) { + bookmarksListEl.appendChild( + createElement("li", { + child: createElement("button.rgthree-button-reset", { + text: `[${b.shortcutKey}] ${b.title}`, + onclick: () => { + b.canvasToBookmark(); + }, + }), + }), + ); + } + } else { + bookmarksListEl.appendChild( + createElement("li.rgthree-message", { + child: createElement("span", { text: "No bookmarks in current workflow." 
}), + }), + ); + } + bookmarksPopup.update(); + }); + bookmarksButton.withPopup(bookmarksPopup, "hover"); + buttons.push(bookmarksButton); + } + + rgthreeButtonGroup = new ComfyButtonGroup(...buttons); + app.menu?.settingsGroup.element.before(rgthreeButtonGroup.element); +} + +app.registerExtension({ + name: "rgthree.TopMenu", + async setup() { + addRgthreeTopBarButtons(); + + CONFIG_SERVICE.addEventListener("config-change", ((e: CustomEvent) => { + if (e.detail?.key?.includes("features.comfy_top_bar_menu")) { + addRgthreeTopBarButtons(); + } + }) as EventListener); + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/config.ts b/rgthree-comfy/src_web/comfyui/config.ts new file mode 100644 index 0000000000000000000000000000000000000000..8999dbe9012f38a8b1d7ac37dc34de4789b86ead --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/config.ts @@ -0,0 +1,376 @@ +import { app } from "scripts/app.js"; +import { RgthreeDialog, RgthreeDialogOptions } from "rgthree/common/dialog.js"; +import { createElement as $el, query as $$ } from "rgthree/common/utils_dom.js"; +import { checkmark, logoRgthree } from "rgthree/common/media/svgs.js"; +import { LogLevel, rgthree } from "./rgthree.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; + +/** Types of config used as a hint for the form handling. */ +enum ConfigType { + UNKNOWN, + BOOLEAN, + STRING, + NUMBER, + ARRAY, +} + +const TYPE_TO_STRING = { + [ConfigType.UNKNOWN]: "unknown", + [ConfigType.BOOLEAN]: "boolean", + [ConfigType.STRING]: "string", + [ConfigType.NUMBER]: "number", + [ConfigType.ARRAY]: "array", +}; + +type ConfigurationSchema = { + key: string; + type: ConfigType; + label: string; + options?: string[] | number[] | ConfigurationSchemaOption[]; + description?: string; + subconfig?: ConfigurationSchema[]; + isDevOnly?: boolean; + onSave?: (value: any) => void; +}; + +type ConfigurationSchemaOption = { value: any; label: string }; + +/** + * A static schema of sorts to layout options found in the config. + */ +const CONFIGURABLE: { [key: string]: ConfigurationSchema[] } = { + features: [ + { + key: "features.patch_recursive_execution", + type: ConfigType.BOOLEAN, + label: "Optimize ComfyUI's Execution", + description: + "Patches ComfyUI's backend execution making complex workflows 1000's of times faster." + + "
⚠️ Disable if execution seems broken due to forward ComfyUI changes.", + }, + { + key: "features.progress_bar.enabled", + type: ConfigType.BOOLEAN, + label: "Prompt Progress Bar", + description: `Shows a minimal progress bar for nodes and steps at the top of the app.`, + subconfig: [ + { + key: "features.progress_bar.height", + type: ConfigType.NUMBER, + label: "Height of the bar", + }, + { + key: "features.progress_bar.position", + type: ConfigType.STRING, + label: "Position at top or bottom of window", + options: ["top", "bottom"], + }, + ], + }, + { + key: "features.import_individual_nodes.enabled", + type: ConfigType.BOOLEAN, + label: "Import Individual Nodes Widgets", + description: + "Dragging & Dropping a similar image/JSON workflow onto (most) current workflow nodes" + + "will allow you to import that workflow's node's widgets when it has the same " + + "id and type. This is useful when you have several images and you'd like to import just " + + "one part of a previous iteration, like a seed, or prompt.", + }, + ], + menus: [ + { + key: "features.comfy_top_bar_menu.enabled", + type: ConfigType.BOOLEAN, + label: "Enable Top Bar Menu", + description: + "Have quick access from ComfyUI's new top bar to rgthree-comfy bookmarks, settings " + + "(and more to come).", + }, + { + key: "features.menu_queue_selected_nodes", + type: ConfigType.BOOLEAN, + label: "Show 'Queue Selected Output Nodes'", + description: + "Will show a menu item in the right-click context menus to queue (only) the selected " + + "output nodes.", + }, + { + key: "features.menu_auto_nest.subdirs", + type: ConfigType.BOOLEAN, + label: "Auto Nest Subdirectories in Menus", + description: + "When a large, flat list of values contain sub-directories, auto nest them. (Like, for " + + "a large list of checkpoints).", + subconfig: [ + { + key: "features.menu_auto_nest.threshold", + type: ConfigType.NUMBER, + label: "Number of items needed to trigger nesting.", + }, + ], + }, + { + key: "features.menu_bookmarks.enabled", + type: ConfigType.BOOLEAN, + label: "Show Bookmarks in context menu", + description: "Will list bookmarks in the rgthree-comfy right-click context menu.", + }, + ], + groups: [ + { + key: "features.group_header_fast_toggle.enabled", + type: ConfigType.BOOLEAN, + label: "Show fast toggles in Group Headers", + description: "Show quick toggles in Groups' Headers to quickly mute and/or bypass.", + subconfig: [ + { + key: "features.group_header_fast_toggle.toggles", + type: ConfigType.ARRAY, + label: "Which toggles to show.", + options: [ + { value: ["mute"], label: "mute only" }, + { value: ["bypass"], label: "bypass only" }, + { value: ["mute", "bypass"], label: "mute and bypass" }, + ], + }, + { + key: "features.group_header_fast_toggle.show", + type: ConfigType.STRING, + label: "When to show them.", + options: [ + { value: "hover", label: "on hover" }, + { value: "always", label: "always" }, + ], + }, + ], + }, + ], + advanced: [ + { + key: "features.show_alerts_for_corrupt_workflows", + type: ConfigType.BOOLEAN, + label: "Detect Corrupt Workflows", + description: + "Will show a message at the top of the screen when loading a workflow that has " + + "corrupt linking data.", + }, + { + key: "log_level", + type: ConfigType.STRING, + label: "Log level for browser dev console.", + description: + "Further down the list, the more verbose logs to the console will be. 
For instance, " + + "selecting 'IMPORTANT' means only important message will be logged to the browser " + + "console, while selecting 'WARN' will log all messages at or higher than WARN, including " + + "'ERROR' and 'IMPORTANT' etc.", + options: ["IMPORTANT", "ERROR", "WARN", "INFO", "DEBUG", "DEV"], + isDevOnly: true, + onSave: function (value: LogLevel) { + rgthree.setLogLevel(value); + }, + }, + { + key: "features.invoke_extensions_async.node_created", + type: ConfigType.BOOLEAN, + label: "Allow other extensions to call nodeCreated on rgthree-nodes.", + isDevOnly: true, + description: + "Do not disable unless you are having trouble (and then file an issue at rgthree-comfy)." + + "Prior to Apr 2024 it was not possible for other extensions to invoke their nodeCreated " + + "event on some rgthree-comfy nodes. Now it's possible and this option is only here in " + + "for easy if something is wrong.", + }, + ], +}; + +/** + * Creates a new fieldrow for main or sub configuration items. + */ +function fieldrow(item: ConfigurationSchema) { + const initialValue = CONFIG_SERVICE.getConfigValue(item.key); + const container = $el(`div.fieldrow.-type-${TYPE_TO_STRING[item.type]}`, { + dataset: { + name: item.key, + initial: initialValue, + type: item.type, + }, + }); + + $el(`label[for="${item.key}"]`, { + children: [ + $el(`span[text="${item.label}"]`), + item.description ? $el("small", { html: item.description }) : null, + ], + parent: container, + }); + + let input; + if (item.options?.length) { + input = $el(`select[id="${item.key}"]`, { + parent: container, + children: item.options.map((o) => { + const label = (o as ConfigurationSchemaOption).label || String(o); + const value = (o as ConfigurationSchemaOption).value || o; + const valueSerialized = JSON.stringify({ value: value }); + return $el(`option[value="${valueSerialized}"]`, { + text: label, + selected: valueSerialized === JSON.stringify({ value: initialValue }), + }); + }), + }); + } else if (item.type === ConfigType.BOOLEAN) { + container.classList.toggle("-checked", !!initialValue); + input = $el(`input[type="checkbox"][id="${item.key}"]`, { + parent: container, + checked: initialValue, + }); + } else { + input = $el(`input[id="${item.key}"]`, { + parent: container, + value: initialValue, + }); + } + $el("div.fieldrow-value", { children: [input], parent: container }); + return container; +} + +/** + * A dialog to edit rgthree-comfy settings and config. + */ +export class RgthreeConfigDialog extends RgthreeDialog { + constructor() { + const content = $el("div"); + + content.appendChild(RgthreeConfigDialog.buildFieldset(CONFIGURABLE["features"]!, "Features")); + content.appendChild(RgthreeConfigDialog.buildFieldset(CONFIGURABLE["menus"]!, "Menus")); + content.appendChild(RgthreeConfigDialog.buildFieldset(CONFIGURABLE["groups"]!, "Groups")); + content.appendChild(RgthreeConfigDialog.buildFieldset(CONFIGURABLE["advanced"]!, "Advanced")); + + content.addEventListener("input", (e) => { + const changed = this.getChangedFormData(); + ($$(".save-button", this.element)[0] as HTMLButtonElement).disabled = + !Object.keys(changed).length; + }); + content.addEventListener("change", (e) => { + const changed = this.getChangedFormData(); + ($$(".save-button", this.element)[0] as HTMLButtonElement).disabled = + !Object.keys(changed).length; + }); + + const dialogOptions: RgthreeDialogOptions = { + class: "-iconed -settings", + title: logoRgthree + `

Settings - rgthree-comfy

`, + content, + onBeforeClose: () => { + const changed = this.getChangedFormData(); + if (Object.keys(changed).length) { + return confirm("Looks like there are unsaved changes. Are you sure you want close?"); + } + return true; + }, + buttons: [ + { + label: "Save", + disabled: true, + className: "rgthree-button save-button -blue", + callback: async (e) => { + const changed = this.getChangedFormData(); + if (!Object.keys(changed).length) { + this.close(); + return; + } + const success = await CONFIG_SERVICE.setConfigValues(changed); + if (success) { + for (const key of Object.keys(changed)) { + Object.values(CONFIGURABLE) + .flat() + .find((f) => f.key === key) + ?.onSave?.(changed[key]); + } + this.close(); + rgthree.showMessage({ + id: "config-success", + message: `${checkmark} Successfully saved rgthree-comfy settings!`, + timeout: 4000, + }); + ($$(".save-button", this.element)[0] as HTMLButtonElement).disabled = true; + } else { + alert("There was an error saving rgthree-comfy configuration."); + } + }, + }, + ], + }; + super(dialogOptions); + } + + private static buildFieldset(datas: ConfigurationSchema[], label: string) { + const fieldset = $el(`fieldset`, { children: [$el(`legend[text="${label}"]`)] }); + for (const data of datas) { + if (data.isDevOnly && !rgthree.isDevMode()) { + continue; + } + const container = $el("div.formrow"); + container.appendChild(fieldrow(data)); + + if (data.subconfig) { + for (const subfeature of data.subconfig) { + container.appendChild(fieldrow(subfeature)); + } + } + fieldset.appendChild(container); + } + return fieldset; + } + + getChangedFormData() { + return $$("[data-name]", this.contentElement).reduce((acc: { [key: string]: any }, el) => { + const name = el.dataset["name"]!; + const type = el.dataset["type"]!; + const initialValue = CONFIG_SERVICE.getConfigValue(name); + let currentValueEl = $$("input, textarea, select", el)[0] as HTMLInputElement; + let currentValue: any = null; + if (type === String(ConfigType.BOOLEAN)) { + currentValue = currentValueEl.checked; + // Not sure I like this side effect in here, but it's easy to just do it now. + el.classList.toggle("-checked", currentValue); + } else { + currentValue = currentValueEl?.value; + if (currentValueEl.nodeName === "SELECT") { + currentValue = JSON.parse(currentValue).value; + } else if (type === String(ConfigType.NUMBER)) { + currentValue = Number(currentValue) || initialValue; + } + } + if (JSON.stringify(currentValue) !== JSON.stringify(initialValue)) { + acc[name] = currentValue; + } + return acc; + }, {}); + } +} + +app.ui.settings.addSetting({ + id: "rgthree.config", + name: "Open rgthree-comfy config", + type: () => { + // Adds a row to open the dialog from the ComfyUI settings. + return $el("tr.rgthree-comfyui-settings-row", { + children: [ + $el("td", { + child: `
${logoRgthree} [rgthree-comfy] configuration / settings
`, + }), + $el("td", { + child: $el('button.rgthree-button.-blue[text="rgthree-comfy settings"]', { + events: { + click: (e: PointerEvent) => { + new RgthreeConfigDialog().show(); + }, + }, + }), + }), + ], + }); + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/constants.ts b/rgthree-comfy/src_web/comfyui/constants.ts new file mode 100644 index 0000000000000000000000000000000000000000..4b18471830ec071fb628b5302b99402af5ef40d5 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/constants.ts @@ -0,0 +1,62 @@ +import {SERVICE as CONFIG_SERVICE} from "./services/config_service.js"; + +export function addRgthree(str: string) { + return str + " (rgthree)"; +} + +export function stripRgthree(str: string) { + return str.replace(/\s*\(rgthree\)$/, ""); +} + +export const NodeTypesString = { + ANY_SWITCH: addRgthree("Any Switch"), + CONTEXT: addRgthree("Context"), + CONTEXT_BIG: addRgthree("Context Big"), + CONTEXT_SWITCH: addRgthree("Context Switch"), + CONTEXT_SWITCH_BIG: addRgthree("Context Switch Big"), + CONTEXT_MERGE: addRgthree("Context Merge"), + CONTEXT_MERGE_BIG: addRgthree("Context Merge Big"), + DYNAMIC_CONTEXT: addRgthree("Dynamic Context"), + DYNAMIC_CONTEXT_SWITCH: addRgthree("Dynamic Context Switch"), + DISPLAY_ANY: addRgthree("Display Any"), + NODE_MODE_RELAY: addRgthree("Mute / Bypass Relay"), + NODE_MODE_REPEATER: addRgthree("Mute / Bypass Repeater"), + FAST_MUTER: addRgthree("Fast Muter"), + FAST_BYPASSER: addRgthree("Fast Bypasser"), + FAST_GROUPS_MUTER: addRgthree("Fast Groups Muter"), + FAST_GROUPS_BYPASSER: addRgthree("Fast Groups Bypasser"), + FAST_ACTIONS_BUTTON: addRgthree("Fast Actions Button"), + LABEL: addRgthree("Label"), + POWER_PROMPT: addRgthree("Power Prompt"), + POWER_PROMPT_SIMPLE: addRgthree("Power Prompt - Simple"), + SDXL_EMPTY_LATENT_IMAGE: addRgthree("SDXL Empty Latent Image"), + SDXL_POWER_PROMPT_POSITIVE: addRgthree("SDXL Power Prompt - Positive"), + SDXL_POWER_PROMPT_NEGATIVE: addRgthree("SDXL Power Prompt - Simple / Negative"), + POWER_LORA_LOADER: addRgthree("Power Lora Loader"), + KSAMPLER_CONFIG: addRgthree("KSampler Config"), + NODE_COLLECTOR: addRgthree("Node Collector"), + REROUTE: addRgthree("Reroute"), + RANDOM_UNMUTER: addRgthree("Random Unmuter"), + SEED: addRgthree("Seed"), + BOOKMARK: addRgthree("Bookmark"), + IMAGE_COMPARER: addRgthree("Image Comparer"), + IMAGE_INSET_CROP: addRgthree("Image Inset Crop"), +}; + +/** + * Gets the list of nodes from NoteTypeString above, filtering any that are not applicable. 
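 *
 * A sketch of the expected result (exact list depends on config; names come from NodeTypesString
 * above with the " (rgthree)" suffix stripped):
 *
 *   getNodeTypeStrings();
 *   // => ["Any Switch", "Bookmark", "Context", "Context Big", ...] sorted alphabetically, and
 *   //    omitting the "Dynamic Context" entries unless `unreleased.dynamic_context.enabled` is set.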
+ */ +export function getNodeTypeStrings() { + return Object.values(NodeTypesString) + .map((i) => stripRgthree(i)) + .filter((i) => { + if ( + i.startsWith("Dynamic Context") && + !CONFIG_SERVICE.getConfigValue("unreleased.dynamic_context.enabled") + ) { + return false; + } + return true; + }) + .sort(); +} diff --git a/rgthree-comfy/src_web/comfyui/context.ts b/rgthree-comfy/src_web/comfyui/context.ts new file mode 100644 index 0000000000000000000000000000000000000000..5e45870c82b5173d8b6777fdaf6a4214b7676361 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/context.ts @@ -0,0 +1,483 @@ +import type { + INodeInputSlot, + INodeOutputSlot, + LGraphCanvas as TLGraphCanvas, + LGraphNode as TLGraphNode, + LLink, +} from "typings/litegraph.js"; +import type { ComfyNodeConstructor, ComfyObjectInfo } from "typings/comfy.js"; +import { app } from "scripts/app.js"; +import { + IoDirection, + addConnectionLayoutSupport, + addMenuItem, + matchLocalSlotsToServer, + replaceNode, +} from "./utils.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { SERVICE as KEY_EVENT_SERVICE } from "./services/key_events_services.js"; +import { RgthreeBaseServerNodeConstructor } from "typings/rgthree.js"; +import { debounce, wait } from "rgthree/common/shared_utils.js"; +import { removeUnusedInputsFromEnd } from "./utils_inputs_outputs.js"; +import { NodeTypesString } from "./constants.js"; + +/** + * Takes a non-context node and determins for its input or output slot, if there is a valid + * connection for an opposite context output or input slot. + */ +function findMatchingIndexByTypeOrName( + otherNode: TLGraphNode, + otherSlot: INodeInputSlot | INodeOutputSlot, + ctxSlots: INodeInputSlot[] | INodeOutputSlot[], +) { + const otherNodeType = (otherNode.type || "").toUpperCase(); + const otherNodeName = (otherNode.title || "").toUpperCase(); + let otherSlotType = otherSlot.type as string; + if (Array.isArray(otherSlotType) || otherSlotType.includes(",")) { + otherSlotType = "COMBO"; + } + const otherSlotName = otherSlot.name.toUpperCase().replace("OPT_", "").replace("_NAME", ""); + let ctxSlotIndex = -1; + if (["CONDITIONING", "INT", "STRING", "FLOAT", "COMBO"].includes(otherSlotType)) { + ctxSlotIndex = ctxSlots.findIndex((ctxSlot) => { + const ctxSlotName = ctxSlot.name.toUpperCase().replace("OPT_", "").replace("_NAME", ""); + let ctxSlotType = ctxSlot.type as string; + if (Array.isArray(ctxSlotType) || ctxSlotType.includes(",")) { + ctxSlotType = "COMBO"; + } + if (ctxSlotType !== otherSlotType) { + return false; + } + // Straightforward matches. + if ( + ctxSlotName === otherSlotName || + (ctxSlotName === "SEED" && otherSlotName.includes("SEED")) || + (ctxSlotName === "STEP_REFINER" && otherSlotName.includes("AT_STEP")) || + (ctxSlotName === "STEP_REFINER" && otherSlotName.includes("REFINER_STEP")) + ) { + return true; + } + // If postive other node, try to match conditining and text. 
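      // (Illustrative: a source node titled "Positive Prompt" with a CONDITIONING output matches
      // the ctx "positive" slot here, and its "text_g" STRING output would match "text_pos_g".)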
+ if ( + (otherNodeType.includes("POSITIVE") || otherNodeName.includes("POSITIVE")) && + ((ctxSlotName === "POSITIVE" && otherSlotType === "CONDITIONING") || + (ctxSlotName === "TEXT_POS_G" && otherSlotName.includes("TEXT_G")) || + (ctxSlotName === "TEXT_POS_L" && otherSlotName.includes("TEXT_L"))) + ) { + return true; + } + if ( + (otherNodeType.includes("NEGATIVE") || otherNodeName.includes("NEGATIVE")) && + ((ctxSlotName === "NEGATIVE" && otherSlotType === "CONDITIONING") || + (ctxSlotName === "TEXT_NEG_G" && otherSlotName.includes("TEXT_G")) || + (ctxSlotName === "TEXT_NEG_L" && otherSlotName.includes("TEXT_L"))) + ) { + return true; + } + return false; + }); + } else { + ctxSlotIndex = ctxSlots.map((s) => s.type).indexOf(otherSlotType); + } + return ctxSlotIndex; +} + +/** + * A Base Context node for other context based nodes to extend. + */ +export class BaseContextNode extends RgthreeBaseServerNode { + constructor(title: string) { + super(title); + } + + // LiteGraph adds more spacing than we want when calculating a nodes' `_collapsed_width`, so we'll + // override it with a setter and re-set it measured exactly as we want. + ___collapsed_width: number = 0; + + //@ts-ignore - TS Doesn't like us overriding a property with accessors but, too bad. + override get _collapsed_width() { + return this.___collapsed_width; + } + + override set _collapsed_width(width: number) { + const canvas = app.canvas as TLGraphCanvas; + const ctx = canvas.canvas.getContext("2d")!; + const oldFont = ctx.font; + ctx.font = canvas.title_text_font; + let title = this.title.trim(); + this.___collapsed_width = 30 + (title ? 10 + ctx.measureText(title).width : 0); + ctx.font = oldFont; + } + + override connectByType( + slot: string | number, + sourceNode: TLGraphNode, + sourceSlotType: string, + optsIn: string, + ): T | null { + let canConnect = + super.connectByType && + super.connectByType.call(this, slot, sourceNode, sourceSlotType, optsIn); + if (!super.connectByType) { + canConnect = LGraphNode.prototype.connectByType.call( + this, + slot, + sourceNode, + sourceSlotType, + optsIn, + ); + } + if (!canConnect && slot === 0) { + const ctrlKey = KEY_EVENT_SERVICE.ctrlKey; + // Okay, we've dragged a context and it can't connect.. let's connect all the other nodes. + // Unfortunately, we don't know which are null now, so we'll just connect any that are + // not already connected. + for (const [index, input] of (sourceNode.inputs || []).entries()) { + if (input.link && !ctrlKey) { + continue; + } + const thisOutputSlot = findMatchingIndexByTypeOrName(sourceNode, input, this.outputs); + if (thisOutputSlot > -1) { + this.connect(thisOutputSlot, sourceNode, index); + } + } + } + return null; + } + + override connectByTypeOutput( + slot: string | number, + sourceNode: TLGraphNode, + sourceSlotType: string, + optsIn: string, + ): T | null { + let canConnect = + super.connectByTypeOutput && + super.connectByTypeOutput.call(this, slot, sourceNode, sourceSlotType, optsIn); + if (!super.connectByType) { + canConnect = LGraphNode.prototype.connectByTypeOutput.call( + this, + slot, + sourceNode, + sourceSlotType, + optsIn, + ); + } + if (!canConnect && slot === 0) { + const ctrlKey = KEY_EVENT_SERVICE.ctrlKey; + // Okay, we've dragged a context and it can't connect.. let's connect all the other nodes. + // Unfortunately, we don't know which are null now, so we'll just connect any that are + // not already connected. 
+ for (const [index, output] of (sourceNode.outputs || []).entries()) { + if (output.links?.length && !ctrlKey) { + continue; + } + const thisInputSlot = findMatchingIndexByTypeOrName(sourceNode, output, this.inputs); + if (thisInputSlot > -1) { + sourceNode.connect(index, this, thisInputSlot); + } + } + } + return null; + } + + static override setUp( + comfyClass: ComfyNodeConstructor, + nodeData: ComfyObjectInfo, + ctxClass: RgthreeBaseServerNodeConstructor, + ) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, ctxClass); + // [🤮] ComfyUI only adds "required" inputs to the outputs list when dragging an output to + // empty space, but since RGTHREE_CONTEXT is optional, it doesn't get added to the menu because + // ...of course. So, we'll manually add it. Of course, we also have to do this in a timeout + // because ComfyUI clears out `LiteGraph.slot_types_default_out` in its own 'Comfy.SlotDefaults' + // extension and we need to wait for that to happen. + wait(500).then(() => { + LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"] = + LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"] || []; + LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"].push(comfyClass.comfyClass); + }); + } + + static override onRegisteredForOverride(comfyClass: any, ctxClass: any) { + addConnectionLayoutSupport(ctxClass, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + setTimeout(() => { + ctxClass.category = comfyClass.category; + }); + } +} + +/** + * The original Context node. + */ +class ContextNode extends BaseContextNode { + static override title = NodeTypesString.CONTEXT; + static override type = NodeTypesString.CONTEXT; + static comfyClass = NodeTypesString.CONTEXT; + + constructor(title = ContextNode.title) { + super(title); + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + BaseContextNode.setUp(comfyClass, nodeData, ContextNode); + } + + static override onRegisteredForOverride(comfyClass: any, ctxClass: any) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextNode, app, { + name: "Convert To Context Big", + callback: (node) => { + replaceNode(node, ContextBigNode.type); + }, + }); + } +} + +/** + * The Context Big node. + */ +class ContextBigNode extends BaseContextNode { + static override title = NodeTypesString.CONTEXT_BIG; + static override type = NodeTypesString.CONTEXT_BIG; + static comfyClass = NodeTypesString.CONTEXT_BIG; + + constructor(title = ContextBigNode.title) { + super(title); + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + BaseContextNode.setUp(comfyClass, nodeData, ContextBigNode); + } + + static override onRegisteredForOverride(comfyClass: any, ctxClass: any) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextBigNode, app, { + name: "Convert To Context (Original)", + callback: (node) => { + replaceNode(node, ContextNode.type); + }, + }); + } +} + +/** + * A base node for Context Switche nodes and Context Merges nodes that will always add another empty + * ctx input, no less than five. + */ +class BaseContextMultiCtxInputNode extends BaseContextNode { + private stabilizeBound = this.stabilize.bind(this); + + constructor(title: string) { + super(title); + // Adding five. Note, configure will add as many as was in the stored workflow automatically. 
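    // (So slots come out named "ctx_01", "ctx_02", ... via addContextInput below, and stabilize()
    // keeps one empty trailing slot around as connections change: trim unused ends, add one back.)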
+ this.addContextInput(5); + } + + private addContextInput(num = 1) { + for (let i = 0; i < num; i++) { + this.addInput(`ctx_${String(this.inputs.length + 1).padStart(2, "0")}`, "RGTHREE_CONTEXT"); + } + } + + override onConnectionsChange( + type: number, + slotIndex: number, + isConnected: boolean, + link: LLink, + ioSlot: INodeInputSlot | INodeOutputSlot, + ): void { + super.onConnectionsChange?.apply(this, [...arguments] as any); + if (type === LiteGraph.INPUT) { + this.scheduleStabilize(); + } + } + + private scheduleStabilize(ms = 64) { + return debounce(this.stabilizeBound, 64); + } + + /** + * Stabilizes the inputs; removing any disconnected ones from the bottom, then adding an empty + * one to the end so we always have one empty one to expand. + */ + private stabilize() { + removeUnusedInputsFromEnd(this, 4); + this.addContextInput(); + } +} + +/** + * The Context Switch (original) node. + */ +class ContextSwitchNode extends BaseContextMultiCtxInputNode { + static override title = NodeTypesString.CONTEXT_SWITCH; + static override type = NodeTypesString.CONTEXT_SWITCH; + static comfyClass = NodeTypesString.CONTEXT_SWITCH; + + constructor(title = ContextSwitchNode.title) { + super(title); + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + BaseContextNode.setUp(comfyClass, nodeData, ContextSwitchNode); + } + + static override onRegisteredForOverride(comfyClass: any, ctxClass: any) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextSwitchNode, app, { + name: "Convert To Context Switch Big", + callback: (node) => { + replaceNode(node, ContextSwitchBigNode.type); + }, + }); + } +} + +/** + * The Context Switch Big node. + */ +class ContextSwitchBigNode extends BaseContextMultiCtxInputNode { + static override title = NodeTypesString.CONTEXT_SWITCH_BIG; + static override type = NodeTypesString.CONTEXT_SWITCH_BIG; + static comfyClass = NodeTypesString.CONTEXT_SWITCH_BIG; + + constructor(title = ContextSwitchBigNode.title) { + super(title); + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + BaseContextNode.setUp(comfyClass, nodeData, ContextSwitchBigNode); + } + + static override onRegisteredForOverride(comfyClass: any, ctxClass: any) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextSwitchBigNode, app, { + name: "Convert To Context Switch", + callback: (node) => { + replaceNode(node, ContextSwitchNode.type); + }, + }); + } +} + +/** + * The Context Merge (original) node. + */ +class ContextMergeNode extends BaseContextMultiCtxInputNode { + static override title = NodeTypesString.CONTEXT_MERGE; + static override type = NodeTypesString.CONTEXT_MERGE; + static comfyClass = NodeTypesString.CONTEXT_MERGE; + + constructor(title = ContextMergeNode.title) { + super(title); + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + BaseContextNode.setUp(comfyClass, nodeData, ContextMergeNode); + } + + static override onRegisteredForOverride(comfyClass: any, ctxClass: any) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextMergeNode, app, { + name: "Convert To Context Merge Big", + callback: (node) => { + replaceNode(node, ContextMergeBigNode.type); + }, + }); + } +} + +/** + * The Context Switch Big node. 
+ */
+class ContextMergeBigNode extends BaseContextMultiCtxInputNode {
+ static override title = NodeTypesString.CONTEXT_MERGE_BIG;
+ static override type = NodeTypesString.CONTEXT_MERGE_BIG;
+ static comfyClass = NodeTypesString.CONTEXT_MERGE_BIG;
+
+ constructor(title = ContextMergeBigNode.title) {
+ super(title);
+ }
+
+ static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) {
+ BaseContextNode.setUp(comfyClass, nodeData, ContextMergeBigNode);
+ }
+
+ static override onRegisteredForOverride(comfyClass: any, ctxClass: any) {
+ BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
+ addMenuItem(ContextMergeBigNode, app, {
+ name: "Convert To Context Merge",
+ callback: (node) => {
+ replaceNode(node, ContextMergeNode.type);
+ },
+ });
+ }
+}
+
+const contextNodes = [
+ ContextNode,
+ ContextBigNode,
+ ContextSwitchNode,
+ ContextSwitchBigNode,
+ ContextMergeNode,
+ ContextMergeBigNode,
+];
+const contextTypeToServerDef: { [type: string]: ComfyObjectInfo } = {};
+
+function fixBadConfigs(node: ContextNode) {
+ // Dumb mistake, but let's fix our misspelling. This will probably need to stay in perpetuity to
+ // keep any old workflows operating.
+ const wrongName = node.outputs.find((o, i) => o.name === "CLIP_HEIGTH");
+ if (wrongName) {
+ wrongName.name = "CLIP_HEIGHT";
+ }
+}
+
+app.registerExtension({
+ name: "rgthree.Context",
+ async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) {
+ // Loop over our context nodes and see if any match the server data.
+ for (const ctxClass of contextNodes) {
+ if (nodeData.name === ctxClass.type) {
+ contextTypeToServerDef[ctxClass.type] = nodeData;
+ ctxClass.setUp(nodeType, nodeData);
+ break;
+ }
+ }
+ },
+
+ async nodeCreated(node: TLGraphNode) {
+ const type = node.type || (node.constructor as any).type;
+ const serverDef = type && contextTypeToServerDef[type];
+ if (serverDef) {
+ fixBadConfigs(node as ContextNode);
+ matchLocalSlotsToServer(node, IoDirection.OUTPUT, serverDef);
+ // Switches don't need to change inputs, only context outputs
+ if (!type!.includes("Switch") && !type!.includes("Merge")) {
+ matchLocalSlotsToServer(node, IoDirection.INPUT, serverDef);
+ }
+ }
+ },
+
+ /**
+ * When we're loaded from the server, check if we're using an out of date version and update our
+ * inputs / outputs to match.
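+ * For example, a Context node saved before a newer output existed should have that output
+ * appended here on load (and an old misspelled "CLIP_HEIGTH" slot renamed) so older workflows
+ * keep working.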
+ */ + async loadedGraphNode(node: TLGraphNode) { + const type = node.type || (node.constructor as any).type; + const serverDef = type && contextTypeToServerDef[type]; + if (serverDef) { + fixBadConfigs(node as ContextNode); + matchLocalSlotsToServer(node, IoDirection.OUTPUT, serverDef); + // Switches don't need to change inputs, only context outputs + if (!type!.includes("Switch") && !type!.includes("Merge")) { + matchLocalSlotsToServer(node, IoDirection.INPUT, serverDef); + } + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/dialog_info.ts b/rgthree-comfy/src_web/comfyui/dialog_info.ts new file mode 100644 index 0000000000000000000000000000000000000000..22e4bb53131582593dabdc2993868f02f702ae80 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/dialog_info.ts @@ -0,0 +1,406 @@ +import { RgthreeDialog, RgthreeDialogOptions } from "rgthree/common/dialog.js"; +import { + createElement as $el, + empty, + appendChildren, + getClosestOrSelf, + queryOne, + query, + setAttributes, +} from "rgthree/common/utils_dom.js"; +import { + logoCivitai, + link, + pencilColored, + diskColored, + dotdotdot, +} from "rgthree/common/media/svgs.js"; +import { RgthreeModelInfo } from "typings/rgthree.js"; +import { SERVICE as MODEL_INFO_SERVICE } from "rgthree/common/model_info_service.js"; +import { rgthree } from "./rgthree.js"; +import { MenuButton } from "rgthree/common/menu.js"; +import { generateId, injectCss } from "rgthree/common/shared_utils.js"; +import { rgthreeApi } from "rgthree/common/rgthree_api.js"; + +/** + * A dialog that displays information about a model/lora/etc. + */ +export class RgthreeInfoDialog extends RgthreeDialog { + private modifiedModelData = false; + private modelInfo: RgthreeModelInfo | null = null; + + constructor(file: string) { + const dialogOptions: RgthreeDialogOptions = { + class: "rgthree-info-dialog", + title: `

Loading...

`, + content: "
Loading..
", + onBeforeClose: () => { + return true; + }, + }; + super(dialogOptions); + this.init(file); + } + + private async init(file: string) { + const cssPromise = injectCss("rgthree/common/css/dialog_model_info.css"); + this.modelInfo = await MODEL_INFO_SERVICE.getLora(file, false, false); + await cssPromise; + this.setContent(this.getInfoContent()); + this.setTitle(this.modelInfo?.["name"] || this.modelInfo?.["file"] || "Unknown"); + this.attachEvents(); + } + + protected override getCloseEventDetail(): { detail: any } { + const detail = { + dirty: this.modifiedModelData, + }; + return { detail }; + } + + private attachEvents() { + this.contentElement.addEventListener("click", async (e: MouseEvent) => { + const target = getClosestOrSelf(e.target as HTMLElement, "[data-action]"); + const action = target?.getAttribute("data-action"); + if (!target || !action) { + return; + } + await this.handleEventAction(action, target, e); + }); + } + + private async handleEventAction(action: string, target: HTMLElement, e?: Event) { + const info = this.modelInfo!; + if (!info?.file) { + return; + } + if (action === "fetch-civitai") { + this.modelInfo = await MODEL_INFO_SERVICE.refreshLora(info.file); + this.setContent(this.getInfoContent()); + this.setTitle(this.modelInfo?.["name"] || this.modelInfo?.["file"] || "Unknown"); + } else if (action === "copy-trained-words") { + const selected = query(".-rgthree-is-selected", target.closest("tr")!); + const text = selected.map((el) => el.getAttribute("data-word")).join(", "); + await navigator.clipboard.writeText(text); + rgthree.showMessage({ + id: "copy-trained-words-" + generateId(4), + type: "success", + message: `Successfully copied ${selected.length} key word${ + selected.length === 1 ? "" : "s" + }.`, + timeout: 4000, + }); + } else if (action === "toggle-trained-word") { + target?.classList.toggle("-rgthree-is-selected"); + const tr = target.closest("tr"); + if (tr) { + const span = queryOne("td:first-child > *", tr)!; + let small = queryOne("small", span); + if (!small) { + small = $el("small", { parent: span }); + } + const num = query(".-rgthree-is-selected", tr).length; + small.innerHTML = num + ? `${num} selected | Copy` + : ""; + // this.handleEventAction('copy-trained-words', target, e); + } + } else if (action === "edit-row") { + const tr = target!.closest("tr")!; + const td = queryOne("td:nth-child(2)", tr)!; + const input = td.querySelector("input,textarea"); + if (!input) { + const fieldName = tr.dataset["fieldName"] as string; + tr.classList.add("-rgthree-editing"); + const isTextarea = fieldName === "userNote"; + const input = $el(`${isTextarea ? 
"textarea" : 'input[type="text"]'}`, { + value: td.textContent, + }); + input.addEventListener("keydown", (e) => { + if (!isTextarea && e.key === "Enter") { + const modified = saveEditableRow(info!, tr, true); + this.modifiedModelData = this.modifiedModelData || modified; + e.stopPropagation(); + e.preventDefault(); + } else if (e.key === "Escape") { + const modified = saveEditableRow(info!, tr, false); + this.modifiedModelData = this.modifiedModelData || modified; + e.stopPropagation(); + e.preventDefault(); + } + }); + appendChildren(empty(td), [input]); + input.focus(); + } else if (target!.nodeName.toLowerCase() === "button") { + const modified = saveEditableRow(info!, tr, true); + this.modifiedModelData = this.modifiedModelData || modified; + } + e?.preventDefault(); + e?.stopPropagation(); + } + } + + private getInfoContent() { + const info = this.modelInfo || {}; + const civitaiLink = info.links?.find((i) => i.includes("civitai.com/models")); + const html = ` +
    +
  • ${info.type || ""}
  • +
  • ${info.baseModel || ""}
  • +
  • + ${ + "" + // !civitaiLink + // ? "" + // : ` + // + // ` + } +
+ + + ${infoTableRow("File", info.file || "")} + ${infoTableRow("Hash (sha256)", info.sha256 || "")} + ${ + civitaiLink + ? infoTableRow( + "Civitai", + `${logoCivitai}View on Civitai`, + ) + : info.raw?.civitai?.error === "Model not found" + ? infoTableRow( + "Civitai", + 'Model not found', + ) + : info.raw?.civitai?.error + ? infoTableRow("Civitai", info.raw?.civitai?.error) + : !info.raw?.civitai + ? infoTableRow( + "Civitai", + ``, + ) + : "" + } + + ${infoTableRow( + "Name", + info.name || info.raw?.metadata?.ss_output_name || "", + "The name for display.", + "name", + )} + + ${ + !info.baseModelFile && !info.baseModelFile + ? "" + : infoTableRow( + "Base Model", + (info.baseModel || "") + (info.baseModelFile ? ` (${info.baseModelFile})` : ""), + ) + } + + + ${ + !info.trainedWords?.length + ? "" + : infoTableRow( + "Trained Words", + getTrainedWordsMarkup(info.trainedWords) ?? "", + "Trained words from the metadata and/or civitai. Click to select for copy.", + ) + } + + ${ + !info.raw?.metadata?.ss_clip_skip || info.raw?.metadata?.ss_clip_skip == "None" + ? "" + : infoTableRow("Clip Skip", info.raw?.metadata?.ss_clip_skip) + } + ${infoTableRow( + "Strength Min", + info.strengthMin ?? "", + "The recommended minimum strength, In the Power Lora Loader node, strength will signal when it is below this threshold.", + "strengthMin", + )} + ${infoTableRow( + "Strength Max", + info.strengthMax ?? "", + "The recommended maximum strength. In the Power Lora Loader node, strength will signal when it is above this threshold.", + "strengthMax", + )} + ${ + "" /*infoTableRow( + "User Tags", + info.userTags?.join(", ") ?? "", + "A list of tags to make filtering easier in the Power Lora Chooser.", + "userTags", + )*/ + } + ${infoTableRow( + "Additional Notes", + info.userNote ?? "", + "Additional notes you'd like to keep and reference in the info dialog.", + "userNote", + )} + +
+ +
    ${ + info.images + ?.map( + (img) => ` +
  • +
    + +
    ${imgInfoField( + "", + img.civitaiUrl + ? `civitai${link}` + : undefined, + )}${imgInfoField("seed", img.seed)}${imgInfoField("steps", img.steps)}${imgInfoField("cfg", img.cfg)}${imgInfoField("sampler", img.sampler)}${imgInfoField("model", img.model)}${imgInfoField("positive", img.positive)}${imgInfoField("negative", img.negative)}
    +
    +
  • `, + ) + .join("") ?? "" + }
+ `; + + const div = $el("div", { html }); + + if (rgthree.isDevMode()) { + setAttributes(queryOne('[stub="menu"]', div)!, { + children: [ + new MenuButton({ + icon: dotdotdot, + options: [ + { label: "More Actions", type: "title" }, + { + label: "Open API JSON", + callback: async (e: PointerEvent) => { + if (this.modelInfo?.file) { + window.open( + `rgthree/api/loras/info?file=${encodeURIComponent(this.modelInfo.file)}`, + ); + } + }, + }, + { + label: "Clear all local info", + callback: async (e: PointerEvent) => { + if (this.modelInfo?.file) { + this.modelInfo = await MODEL_INFO_SERVICE.clearLoraFetchedData( + this.modelInfo.file, + ); + this.setContent(this.getInfoContent()); + this.setTitle( + this.modelInfo?.["name"] || this.modelInfo?.["file"] || "Unknown", + ); + } + }, + }, + ], + }), + ], + }); + } + + return div; + } +} + +/** + * Generates a uniform markup string for a table row. + */ +function infoTableRow( + name: string, + value: string | number, + help: string = "", + editableFieldName = "", +) { + return ` + + ${name} ${help ? `` : ""} + ${ + String(value).startsWith("<") ? value : `${value}` + } + ${ + editableFieldName + ? `` + : "" + } + `; +} + +function getTrainedWordsMarkup(words: RgthreeModelInfo["trainedWords"]) { + let markup = `
    `; + for (const wordData of words || []) { + markup += `
  • + ${wordData.word} + ${wordData.civitai ? logoCivitai : ""} + ${wordData.count != null ? `${wordData.count}` : ""} +
  • `; + } + markup += `
`; + return markup; +} + +/** + * Saves / cancels an editable row. Returns a boolean if the data was modified. + */ +function saveEditableRow(info: RgthreeModelInfo, tr: HTMLElement, saving = true): boolean { + const fieldName = tr.dataset["fieldName"] as "file"; + const input = queryOne("input,textarea", tr)!; + let newValue = info[fieldName] ?? ""; + let modified = false; + if (saving) { + newValue = input!.value; + if (fieldName.startsWith("strength")) { + if (Number.isNaN(Number(newValue))) { + alert(`You must enter a number into the ${fieldName} field.`); + return false; + } + newValue = (Math.round(Number(newValue) * 100) / 100).toFixed(2); + } + MODEL_INFO_SERVICE.saveLoraPartial(info.file!, { [fieldName]: newValue }); + modified = true; + } + tr.classList.remove("-rgthree-editing"); + const td = queryOne("td:nth-child(2)", tr)!; + appendChildren(empty(td), [$el("span", { text: newValue })]); + return modified; +} + +function imgInfoField(label: string, value?: string | number) { + return value != null ? `${label ? `` : ""}${value}` : ""; +} diff --git a/rgthree-comfy/src_web/comfyui/display_any.ts b/rgthree-comfy/src_web/comfyui/display_any.ts new file mode 100644 index 0000000000000000000000000000000000000000..8562a5163be891ec1037c7f69aa0b7231b8f59b0 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/display_any.ts @@ -0,0 +1,73 @@ +import { app } from "scripts/app.js"; +import { ComfyWidgets } from "scripts/widgets.js"; +import type { LGraphNode as TLGraphNode } from "typings/litegraph.js"; +import type { ComfyApp, ComfyNodeConstructor, ComfyObjectInfo } from "typings/comfy.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { rgthree } from "./rgthree.js"; + +let hasShownAlertForUpdatingInt = false; + +app.registerExtension({ + name: "rgthree.DisplayAny", + async beforeRegisterNodeDef( + nodeType: ComfyNodeConstructor, + nodeData: ComfyObjectInfo, + app: ComfyApp, + ) { + if (nodeData.name === "Display Any (rgthree)" || nodeData.name === "Display Int (rgthree)") { + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + onNodeCreated ? onNodeCreated.apply(this, []) : undefined; + + (this as any).showValueWidget = ComfyWidgets["STRING"]( + this, + "output", + ["STRING", { multiline: true }], + app, + ).widget; + (this as any).showValueWidget.inputEl!.readOnly = true; + (this as any).showValueWidget.serializeValue = async (node: TLGraphNode, index: number) => { + const n = + rgthree.getNodeFromInitialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff(node); + if (n) { + // Since we need a round trip to get the value, the serizalized value means nothing, and + // saving it to the metadata would just be confusing. So, we clear it here. + n.widgets_values![index] = ""; + } else { + console.warn( + "No serialized node found in workflow. May be attributed to " + + "https://github.com/comfyanonymous/ComfyUI/issues/2193", + ); + } + return ""; + }; + }; + + addConnectionLayoutSupport(nodeType, app, [["Left"], ["Right"]]); + + const onExecuted = nodeType.prototype.onExecuted; + nodeType.prototype.onExecuted = function (message: any) { + onExecuted?.apply(this, [message]); + (this as any).showValueWidget.value = message.text[0]; + }; + } + }, + + // This ports Display Int to DisplayAny, but ComfyUI still shows an error. + // If https://github.com/comfyanonymous/ComfyUI/issues/1527 is fixed, this could work. 
+ // async loadedGraphNode(node: TLGraphNode) { + // if (node.type === "Display Int (rgthree)") { + // replaceNode(node, "Display Any (rgthree)", new Map([["input", "source"]])); + // if (!hasShownAlertForUpdatingInt) { + // hasShownAlertForUpdatingInt = true; + // setTimeout(() => { + // alert( + // "Don't worry, your 'Display Int' nodes have been updated to the new " + + // "'Display Any' nodes! You can ignore the error message underneath (for that node)." + + // "\n\nThanks.\n- rgthree", + // ); + // }, 128); + // } + // } + // }, +}); diff --git a/rgthree-comfy/src_web/comfyui/dynamic_context.ts b/rgthree-comfy/src_web/comfyui/dynamic_context.ts new file mode 100644 index 0000000000000000000000000000000000000000..5e0357602fb1332997be3c6216bc995bd254d07f --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/dynamic_context.ts @@ -0,0 +1,297 @@ +import {app} from "scripts/app.js"; +import { + IoDirection, + followConnectionUntilType, + getConnectedInputInfosAndFilterPassThroughs, +} from "./utils.js"; +import {rgthree} from "./rgthree.js"; +import { + SERVICE as CONTEXT_SERVICE, + InputMutation, + InputMutationOperation, +} from "./services/context_service.js"; +import {NodeTypesString} from "./constants.js"; +import {removeUnusedInputsFromEnd} from "./utils_inputs_outputs.js"; +import {INodeInputSlot, INodeOutputSlot, INodeSlot, LGraphNode, LLink} from "typings/litegraph.js"; +import {ComfyNodeConstructor, ComfyObjectInfo} from "typings/comfy.js"; +import {DynamicContextNodeBase} from "./dynamic_context_base.js"; +import {SERVICE as CONFIG_SERVICE} from "./services/config_service.js"; + +const OWNED_PREFIX = "+"; +const REGEX_OWNED_PREFIX = /^\+\s*/; +const REGEX_EMPTY_INPUT = /^\+\s*$/; + +/** + * The Dynamic Context node. + */ +export class DynamicContextNode extends DynamicContextNodeBase { + static override title = NodeTypesString.DYNAMIC_CONTEXT; + static override type = NodeTypesString.DYNAMIC_CONTEXT; + static comfyClass = NodeTypesString.DYNAMIC_CONTEXT; + + constructor(title = DynamicContextNode.title) { + super(title); + } + + override onNodeCreated() { + this.addInput("base_ctx", "RGTHREE_DYNAMIC_CONTEXT"); + this.ensureOneRemainingNewInputSlot(); + super.onNodeCreated(); + } + + override onConnectionsChange( + type: number, + slotIndex: number, + isConnected: boolean, + link: LLink, + ioSlot: INodeSlot, + ): void { + super.onConnectionsChange?.call(this, type, slotIndex, isConnected, link, ioSlot); + if (this.configuring) { + return; + } + if (type === LiteGraph.INPUT) { + if (isConnected) { + this.handleInputConnected(slotIndex); + } else { + this.handleInputDisconnected(slotIndex); + } + } + } + + override onConnectInput( + inputIndex: number, + outputType: INodeOutputSlot["type"], + outputSlot: INodeOutputSlot, + outputNode: LGraphNode, + outputIndex: number, + ): boolean { + let canConnect = true; + if (super.onConnectInput) { + canConnect = super.onConnectInput.apply(this, [...arguments] as any); + } + if ( + canConnect && + outputNode instanceof DynamicContextNode && + outputIndex === 0 && + inputIndex !== 0 + ) { + const [n, v] = rgthree.logger.warnParts( + "Currently, you can only connect a context node in the first slot.", + ); + console[n]?.call(console, ...v); + canConnect = false; + } + return canConnect; + } + + handleInputConnected(slotIndex: number) { + const ioSlot = this.inputs[slotIndex]; + const connectedIndexes = []; + if (slotIndex === 0) { + let baseNodeInfos = getConnectedInputInfosAndFilterPassThroughs(this, this, 0); + const baseNodes = 
baseNodeInfos.map((n) => n.node)!; + const baseNodesDynamicCtx = baseNodes[0] as DynamicContextNodeBase; + if (baseNodesDynamicCtx?.provideInputsData) { + const inputsData = CONTEXT_SERVICE.getDynamicContextInputsData(baseNodesDynamicCtx); + console.log("inputsData", inputsData); + for (const input of baseNodesDynamicCtx.provideInputsData()) { + if (input.name === "base_ctx" || input.name === "+") { + continue; + } + this.addContextInput(input.name, input.type, input.index); + this.stabilizeNames(); + } + } + } else if (this.isInputSlotForNewInput(slotIndex)) { + this.handleNewInputConnected(slotIndex); + } + } + + isInputSlotForNewInput(slotIndex: number) { + const ioSlot = this.inputs[slotIndex]; + return ioSlot && ioSlot.name === "+" && ioSlot.type === "*"; + } + + handleNewInputConnected(slotIndex: number) { + if (!this.isInputSlotForNewInput(slotIndex)) { + throw new Error('Expected the incoming slot index to be the "new input" input.'); + } + const ioSlot = this.inputs[slotIndex]!; + let cxn = null; + if (ioSlot.link != null) { + cxn = followConnectionUntilType(this, IoDirection.INPUT, slotIndex, true); + } + if (cxn?.type && cxn?.name) { + let name = this.addOwnedPrefix(this.getNextUniqueNameForThisNode(cxn.name)); + if (name.match(/^\+\s*[A-Z_]+(\.\d+)?$/)) { + name = name.toLowerCase(); + } + ioSlot.name = name; + ioSlot.type = cxn.type as string; + ioSlot.removable = true; + while (!this.outputs[slotIndex]) { + this.addOutput("*", "*"); + } + this.outputs[slotIndex]!.type = cxn.type as string; + this.outputs[slotIndex]!.name = this.stripOwnedPrefix(name).toLocaleUpperCase(); + // This is a dumb override for ComfyUI's widgetinputs issues. + if (cxn.type === "COMBO" || cxn.type.includes(",") || Array.isArray(cxn.type)) { + (this.outputs[slotIndex] as any).widget = true; + } + this.inputsMutated({ + operation: InputMutationOperation.ADDED, + node: this, + slotIndex, + slot: ioSlot, + }); + this.stabilizeNames(); + this.ensureOneRemainingNewInputSlot(); + } + } + + handleInputDisconnected(slotIndex: number) { + const inputs = this.getContextInputsList(); + if (slotIndex === 0) { + for (let index = inputs.length - 1; index > 0; index--) { + if (index === 0 || index === inputs.length - 1) { + continue; + } + const input = inputs[index]!; + if (!this.isOwnedInput(input.name)) { + if (input.link || this.outputs[index]?.links?.length) { + this.renameContextInput(index, input.name, true); + } else { + this.removeContextInput(index); + } + } + } + this.setSize(this.computeSize()); + this.setDirtyCanvas(true, true); + } + } + + ensureOneRemainingNewInputSlot() { + removeUnusedInputsFromEnd(this, 1, REGEX_EMPTY_INPUT); + this.addInput(OWNED_PREFIX, "*"); + } + + getNextUniqueNameForThisNode(desiredName: string) { + const inputs = this.getContextInputsList(); + const allExistingKeys = inputs.map((i) => this.stripOwnedPrefix(i.name).toLocaleUpperCase()); + desiredName = this.stripOwnedPrefix(desiredName); + let newName = desiredName; + let n = 0; + while (allExistingKeys.includes(newName.toLocaleUpperCase())) { + newName = `${desiredName}.${++n}`; + } + return newName; + } + + override removeInput(slotIndex: number) { + const slot = this.inputs[slotIndex]!; + super.removeInput(slotIndex); + if (this.outputs[slotIndex]) { + this.removeOutput(slotIndex); + } + this.inputsMutated({operation: InputMutationOperation.REMOVED, node: this, slotIndex, slot}); + this.stabilizeNames(); + } + + stabilizeNames() { + const inputs = this.getContextInputsList(); + const names: string[] = []; + for (const 
[index, input] of inputs.entries()) { + if (index === 0 || index === inputs.length - 1) { + continue; + } + input.label = undefined; + this.outputs[index]!.label = undefined; + let origName = this.stripOwnedPrefix(input.name).replace(/\.\d+$/, ""); + let name = input.name; + if (!this.isOwnedInput(name)) { + names.push(name.toLocaleUpperCase()); + } else { + let n = 0; + name = this.addOwnedPrefix(origName); + while (names.includes(this.stripOwnedPrefix(name).toLocaleUpperCase())) { + name = `${this.addOwnedPrefix(origName)}.${++n}`; + } + names.push(this.stripOwnedPrefix(name).toLocaleUpperCase()); + if (input.name !== name) { + this.renameContextInput(index, name); + } + } + } + } + + override getSlotMenuOptions(slot: { + slot: number; + input?: INodeInputSlot | undefined; + output?: INodeOutputSlot | undefined; + }) { + const editable = this.isOwnedInput(slot.input!.name) && this.type !== "*"; + return [ + { + content: "✏️ Rename Input", + disabled: !editable, + callback: () => { + var dialog = app.canvas.createDialog( + "Name", + {}, + ); + var dialogInput = dialog.querySelector("input")!; + if (dialogInput) { + dialogInput.value = this.stripOwnedPrefix(slot.input!.name || ""); + } + var inner = () => { + this.handleContextMenuRenameInputDialog(slot.slot, dialogInput.value); + dialog.close(); + }; + dialog.querySelector("button")!.addEventListener("click", inner); + dialogInput.addEventListener("keydown", (e) => { + dialog.is_modified = true; + if (e.keyCode == 27) { + dialog.close(); + } else if (e.keyCode == 13) { + inner(); + } else if (e.keyCode != 13 && (e.target as HTMLElement)?.localName != "textarea") { + return; + } + e.preventDefault(); + e.stopPropagation(); + }); + dialogInput.focus(); + }, + }, + { + content: "🗑️ Delete Input", + disabled: !editable, + callback: () => { + this.removeInput(slot.slot); + }, + }, + ]; + } + + handleContextMenuRenameInputDialog(slotIndex: number, value: string) { + app.graph.beforeChange(); + this.renameContextInput(slotIndex, value); + this.stabilizeNames(); + this.setDirtyCanvas(true, true); + app.graph.afterChange(); + } +} + +const contextDynamicNodes = [DynamicContextNode]; +app.registerExtension({ + name: "rgthree.DynamicContext", + async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + if (!CONFIG_SERVICE.getConfigValue("unreleased.dynamic_context.enabled")) { + return; + } + if (nodeData.name === DynamicContextNode.type) { + DynamicContextNode.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/dynamic_context_base.ts b/rgthree-comfy/src_web/comfyui/dynamic_context_base.ts new file mode 100644 index 0000000000000000000000000000000000000000..6fe703b9c3636fcf5e36098022a6bdc72349b683 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/dynamic_context_base.ts @@ -0,0 +1,237 @@ +import type {INodeInputSlot} from "typings/litegraph.js"; + +import {BaseContextNode} from "./context.js"; +import {ComfyNodeConstructor, ComfyObjectInfo} from "typings/comfy.js"; +import {RgthreeBaseServerNode} from "./base_node.js"; +import {moveArrayItem, wait} from "rgthree/common/shared_utils.js"; +import {RgthreeInvisibleWidget} from "./utils_widgets.js"; +import { + getContextOutputName, + InputMutation, + InputMutationOperation, +} from "./services/context_service.js"; +import {app} from "scripts/app.js"; +import {SERVICE as CONTEXT_SERVICE} from "./services/context_service.js"; + +const OWNED_PREFIX = "+"; +const REGEX_OWNED_PREFIX = /^\+\s*/; +const REGEX_EMPTY_INPUT = /^\+\s*$/; + +export 
type InputLike = { + name: string; + type: string | -1; + label?: string; + link: number | null; + removable?: boolean; +}; + +/** + * The base context node that contains some shared between DynamicContext nodes. Not labels + * `abstract` so we can reference `this` in static methods. + */ +export class DynamicContextNodeBase extends BaseContextNode { + protected readonly hasShadowInputs: boolean = false; + + getContextInputsList(): InputLike[] { + return this.inputs; + } + + provideInputsData() { + const inputs = this.getContextInputsList(); + return inputs + .map((input, index) => ({ + name: this.stripOwnedPrefix(input.name), + type: String(input.type), + index, + })) + .filter((i) => i.type !== "*"); + } + + addOwnedPrefix(name: string) { + return `+ ${this.stripOwnedPrefix(name)}`; + } + + isOwnedInput(inputOrName: string | null | INodeInputSlot) { + const name = typeof inputOrName == "string" ? inputOrName : inputOrName?.name || ""; + return REGEX_OWNED_PREFIX.test(name); + } + + stripOwnedPrefix(name: string) { + return name.replace(REGEX_OWNED_PREFIX, ""); + } + + // handleUpstreamMutation(mutation: InputMutation) { + // throw new Error('handleUpstreamMutation not overridden!') + // } + + handleUpstreamMutation(mutation: InputMutation) { + console.log(`[node ${this.id}] handleUpstreamMutation`, mutation); + if (mutation.operation === InputMutationOperation.ADDED) { + const slot = mutation.slot; + if (!slot) { + throw new Error("Cannot have an ADDED mutation without a provided slot data."); + } + this.addContextInput( + this.stripOwnedPrefix(slot.name), + slot.type as string, + mutation.slotIndex, + ); + return; + } + if (mutation.operation === InputMutationOperation.REMOVED) { + const slot = mutation.slot; + if (!slot) { + throw new Error("Cannot have an REMOVED mutation without a provided slot data."); + } + this.removeContextInput(mutation.slotIndex); + return; + } + if (mutation.operation === InputMutationOperation.RENAMED) { + const slot = mutation.slot; + if (!slot) { + throw new Error("Cannot have an RENAMED mutation without a provided slot data."); + } + this.renameContextInput(mutation.slotIndex, slot.name); + return; + } + } + override clone() { + const cloned = super.clone(); + while (cloned.inputs.length > 1) { + cloned.removeInput(cloned.inputs.length - 1); + } + while (cloned.widgets.length > 1) { + cloned.removeWidget(cloned.widgets.length - 1); + } + while (cloned.outputs.length > 1) { + cloned.removeOutput(cloned.outputs.length - 1); + } + return cloned; + } + + /** + * Adds the basic output_keys widget. Should be called _after_ specific nodes setup their inputs + * or widgets. 
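+ *
+ * For example (output names here are illustrative only), a node whose outputs are CONTEXT, MODEL
+ * and CLIP would have this widget serialize "MODEL,CLIP": every output name except the first,
+ * joined by commas.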
+ */ + override onNodeCreated() { + const node = this; + this.addCustomWidget( + new RgthreeInvisibleWidget("output_keys", "RGTHREE_DYNAMIC_CONTEXT_OUTPUTS", "", () => { + return (node.outputs || []) + .map((o, i) => i > 0 && o.name) + .filter((n) => n !== false) + .join(","); + }), + ); + } + + addContextInput(name: string, type: string, slot = -1) { + const inputs = this.getContextInputsList(); + if (this.hasShadowInputs) { + inputs.push({name, type, link: null}); + } else { + this.addInput(name, type); + } + if (slot > -1) { + moveArrayItem(inputs, inputs.length - 1, slot); + } else { + slot = inputs.length - 1; + } + if (type !== "*") { + const output = this.addOutput(getContextOutputName(name), type); + if (type === "COMBO" || String(type).includes(",") || Array.isArray(type)) { + (output as any).widget = true; + } + if (slot > -1) { + moveArrayItem(this.outputs, this.outputs.length - 1, slot); + } + } + this.fixInputsOutputsLinkSlots(); + this.inputsMutated({ + operation: InputMutationOperation.ADDED, + node: this, + slotIndex: slot, + slot: inputs[slot]!, + }); + } + + removeContextInput(slotIndex: number) { + if (this.hasShadowInputs) { + const inputs = this.getContextInputsList(); + const input = inputs.splice(slotIndex, 1)[0]; + if (this.outputs[slotIndex]) { + this.removeOutput(slotIndex); + } + } else { + this.removeInput(slotIndex); + } + } + + renameContextInput(index: number, newName: string, forceOwnBool: boolean | null = null) { + const inputs = this.getContextInputsList(); + const input = inputs[index]!; + const oldName = input.name; + newName = this.stripOwnedPrefix(newName.trim() || this.getSlotDefaultInputLabel(index)); + if (forceOwnBool === true || (this.isOwnedInput(oldName) && forceOwnBool !== false)) { + newName = this.addOwnedPrefix(newName); + } + if (oldName !== newName) { + input.name = newName; + input.removable = this.isOwnedInput(newName); + this.outputs[index]!.name = getContextOutputName(inputs[index]!.name); + this.inputsMutated({ + node: this, + operation: InputMutationOperation.RENAMED, + slotIndex: index, + slot: input, + }); + } + } + + getSlotDefaultInputLabel(slotIndex: number) { + const inputs = this.getContextInputsList(); + const input = inputs[slotIndex]!; + let defaultLabel = this.stripOwnedPrefix(input.name).toLowerCase(); + return defaultLabel.toLocaleLowerCase(); + } + + inputsMutated(mutation: InputMutation) { + CONTEXT_SERVICE.onInputChanges(this, mutation); + } + + fixInputsOutputsLinkSlots() { + if (!this.hasShadowInputs) { + const inputs = this.getContextInputsList(); + for (let index = inputs.length - 1; index > 0; index--) { + const input = inputs[index]!; + if ((input === null || input === void 0 ? void 0 : input.link) != null) { + app.graph.links[input.link!]!.target_slot = index; + } + } + } + const outputs = this.outputs; + for (let index = outputs.length - 1; index > 0; index--) { + const output = outputs[index]; + if (output) { + output.nameLocked = true; + for (const link of output.links || []) { + app.graph.links[link!]!.origin_slot = index; + } + } + } + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, this); + // [🤮] ComfyUI only adds "required" inputs to the outputs list when dragging an output to + // empty space, but since RGTHREE_CONTEXT is optional, it doesn't get added to the menu because + // ...of course. So, we'll manually add it. 
Of course, we also have to do this in a timeout + // because ComfyUI clears out `LiteGraph.slot_types_default_out` in its own 'Comfy.SlotDefaults' + // extension and we need to wait for that to happen. + wait(500).then(() => { + LiteGraph.slot_types_default_out["RGTHREE_DYNAMIC_CONTEXT"] = + LiteGraph.slot_types_default_out["RGTHREE_DYNAMIC_CONTEXT"] || []; + LiteGraph.slot_types_default_out["RGTHREE_DYNAMIC_CONTEXT"].push(comfyClass.comfyClass); + }); + } +} diff --git a/rgthree-comfy/src_web/comfyui/dynamic_context_switch.ts b/rgthree-comfy/src_web/comfyui/dynamic_context_switch.ts new file mode 100644 index 0000000000000000000000000000000000000000..34c7db7497c9c3acec5af047e487a813ed52daae --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/dynamic_context_switch.ts @@ -0,0 +1,207 @@ +import type {ComfyNodeConstructor, ComfyObjectInfo} from "typings/comfy.js"; +import type {INodeSlot, LGraphNode, LLink, LGraphCanvas} from "typings/litegraph.js"; + +import {app} from "scripts/app.js"; +import {DynamicContextNodeBase, InputLike} from "./dynamic_context_base.js"; +import {NodeTypesString} from "./constants.js"; +import { + InputMutation, + SERVICE as CONTEXT_SERVICE, + stripContextInputPrefixes, + getContextOutputName, +} from "./services/context_service.js"; +import {getConnectedInputNodesAndFilterPassThroughs} from "./utils.js"; +import {debounce, moveArrayItem} from "rgthree/common/shared_utils.js"; +import {measureText} from "./utils_canvas.js"; +import {SERVICE as CONFIG_SERVICE} from "./services/config_service.js"; + +type ShadowInputData = { + node: LGraphNode; + slot: number; + shadowIndex: number; + shadowIndexIfShownSingularly: number; + shadowIndexFull: number; + nodeIndex: number; + type: string | -1; + name: string; + key: string; + // isDuplicatedBefore: boolean, + duplicatesBefore: number[]; + duplicatesAfter: number[]; +}; + +/** + * The Context Switch node. + */ +class DynamicContextSwitchNode extends DynamicContextNodeBase { + static override title = NodeTypesString.DYNAMIC_CONTEXT_SWITCH; + static override type = NodeTypesString.DYNAMIC_CONTEXT_SWITCH; + static comfyClass = NodeTypesString.DYNAMIC_CONTEXT_SWITCH; + + protected override readonly hasShadowInputs = true; + + // override hasShadowInputs = true; + + /** + * We should be able to assume that `lastInputsList` is the input list after the last, major + * synchronous change. Which should mean, if we're handling a change that is currently live, but + * not represented in our node (like, an upstream node has already removed an input), then we + * should be able to compar the current InputList to this `lastInputsList`. 
+ */ + lastInputsList: ShadowInputData[] = []; + + private shadowInputs: (InputLike & {count: number})[] = [ + {name: "base_ctx", type: "RGTHREE_DYNAMIC_CONTEXT", link: null, count: 0}, + ]; + + constructor(title = DynamicContextSwitchNode.title) { + super(title); + } + + override getContextInputsList() { + return this.shadowInputs; + } + override handleUpstreamMutation(mutation: InputMutation) { + this.scheduleHardRefresh(); + } + + override onConnectionsChange( + type: number, + slotIndex: number, + isConnected: boolean, + link: LLink, + ioSlot: INodeSlot, + ): void { + super.onConnectionsChange?.call(this, type, slotIndex, isConnected, link, ioSlot); + if (this.configuring) { + return; + } + if (type === LiteGraph.INPUT) { + this.scheduleHardRefresh(); + } + } + + scheduleHardRefresh(ms = 64) { + return debounce(() => { + this.refreshInputsAndOutputs(); + }, ms); + } + + override onNodeCreated() { + this.addInput("ctx_1", "RGTHREE_DYNAMIC_CONTEXT"); + this.addInput("ctx_2", "RGTHREE_DYNAMIC_CONTEXT"); + this.addInput("ctx_3", "RGTHREE_DYNAMIC_CONTEXT"); + this.addInput("ctx_4", "RGTHREE_DYNAMIC_CONTEXT"); + this.addInput("ctx_5", "RGTHREE_DYNAMIC_CONTEXT"); + super.onNodeCreated(); + } + + override addContextInput(name: string, type: string, slot?: number): void {} + + /** + * This is a "hard" refresh of the list, but looping over the actual context inputs, and + * recompiling the shadowInputs and outputs. + */ + private refreshInputsAndOutputs() { + const inputs: (InputLike & {count: number})[] = [ + {name: "base_ctx", type: "RGTHREE_DYNAMIC_CONTEXT", link: null, count: 0}, + ]; + let numConnected = 0; + for (let i = 0; i < this.inputs.length; i++) { + const childCtxs = getConnectedInputNodesAndFilterPassThroughs( + this, + this, + i, + ) as DynamicContextNodeBase[]; + if (childCtxs.length > 1) { + throw new Error("How is there more than one input?"); + } + const ctx = childCtxs[0]; + if (!ctx) continue; + numConnected++; + const slotsData = CONTEXT_SERVICE.getDynamicContextInputsData(ctx); + console.log(slotsData); + for (const slotData of slotsData) { + const found = inputs.find( + (n) => getContextOutputName(slotData.name) === getContextOutputName(n.name), + ); + if (found) { + found.count += 1; + continue; + } + inputs.push({ + name: slotData.name, + type: slotData.type, + link: null, + count: 1, + }); + } + } + this.shadowInputs = inputs; + // First output is always CONTEXT, so "p" is the offset. + let i = 0; + for (i; i < this.shadowInputs.length; i++) { + const data = this.shadowInputs[i]!; + let existing = this.outputs.find( + (o) => getContextOutputName(o.name) === getContextOutputName(data.name), + ); + if (!existing) { + existing = this.addOutput(getContextOutputName(data.name), data.type); + } + moveArrayItem(this.outputs, existing, i); + delete existing.rgthree_status; + if (data.count !== numConnected) { + existing.rgthree_status = "WARN"; + } + } + while (this.outputs[i]) { + const output = this.outputs[i]; + if (output?.links?.length) { + output.rgthree_status = "ERROR"; + i++; + } else { + this.removeOutput(i); + } + } + this.fixInputsOutputsLinkSlots(); + } + + override onDrawForeground(ctx: CanvasRenderingContext2D, canvas: LGraphCanvas): void { + const low_quality = (canvas?.ds?.scale ?? 
1) < 0.6; + if (low_quality || this.size[0] <= 10) { + return; + } + let y = LiteGraph.NODE_SLOT_HEIGHT - 1; + const w = this.size[0]; + ctx.save(); + ctx.font = "normal " + LiteGraph.NODE_SUBTEXT_SIZE + "px Arial"; + ctx.textAlign = "right"; + + for (const output of this.outputs) { + if (!output.rgthree_status) { + y += LiteGraph.NODE_SLOT_HEIGHT; + continue; + } + const x = w - 20 - measureText(ctx, output.name); + if (output.rgthree_status === "ERROR") { + ctx.fillText("🛑", x, y); + } else if (output.rgthree_status === "WARN") { + ctx.fillText("⚠️", x, y); + } + y += LiteGraph.NODE_SLOT_HEIGHT; + } + ctx.restore(); + } +} + +app.registerExtension({ + name: "rgthree.DynamicContextSwitch", + async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + if (!CONFIG_SERVICE.getConfigValue("unreleased.dynamic_context.enabled")) { + return; + } + if (nodeData.name === DynamicContextSwitchNode.type) { + DynamicContextSwitchNode.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/fast_actions_button.ts b/rgthree-comfy/src_web/comfyui/fast_actions_button.ts new file mode 100644 index 0000000000000000000000000000000000000000..aebada98d6857940615295c8f4c4e3d9e7e3b0fc --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/fast_actions_button.ts @@ -0,0 +1,337 @@ +import type { RgthreeBaseVirtualNodeConstructor } from "typings/rgthree.js"; +import type { ComfyApp, ComfyWidget } from "typings/comfy.js"; +import type { IWidget, LGraph, LGraphNode, SerializedLGraphNode } from "typings/litegraph.js"; +import type { RgthreeBaseNode } from "./base_node.js"; + +import { app } from "scripts/app.js"; +import { BaseAnyInputConnectedNode } from "./base_any_input_connected_node.js"; +import { NodeTypesString } from "./constants.js"; +import { addMenuItem } from "./utils.js"; +import { rgthree } from "./rgthree.js"; + +const MODE_ALWAYS = 0; +const MODE_MUTE = 2; +const MODE_BYPASS = 4; + +/** + * The Fast Actions Button. + * + * This adds a button that the user can connect any node to and then choose an action to take on + * that node when the button is pressed. Default actions are "Mute," "Bypass," and "Enable," but + * Nodes can expose actions additional actions that can then be called back. 
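+ *
+ * For example (the action name below is illustrative, not part of this codebase), a node class
+ * could declare:
+ *
+ *   static exposedActions = ["Reset Seed"];
+ *   async handleAction(action: string) { ... }
+ *
+ * and "Reset Seed" would then show up in this button's per-node dropdown and be passed to
+ * `handleAction` when the button fires.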
+ */ +class FastActionsButton extends BaseAnyInputConnectedNode { + static override type = NodeTypesString.FAST_ACTIONS_BUTTON; + static override title = NodeTypesString.FAST_ACTIONS_BUTTON; + override comfyClass = NodeTypesString.FAST_ACTIONS_BUTTON; + + readonly logger = rgthree.newLogSession("[FastActionsButton]"); + + static "@buttonText" = { type: "string" }; + static "@shortcutModifier" = { + type: "combo", + values: ["ctrl", "alt", "shift"], + }; + static "@shortcutKey" = { type: "string" }; + + static collapsible = false; + + override readonly isVirtualNode = true; + + override serialize_widgets = true; + + readonly buttonWidget: IWidget; + + readonly widgetToData = new Map(); + readonly nodeIdtoFunctionCache = new Map(); + + readonly keypressBound; + readonly keyupBound; + + private executingFromShortcut = false; + + constructor(title?: string) { + super(title); + this.properties["buttonText"] = "🎬 Action!"; + this.properties["shortcutModifier"] = "alt"; + this.properties["shortcutKey"] = ""; + this.buttonWidget = this.addWidget( + "button", + this.properties["buttonText"], + null, + () => { + this.executeConnectedNodes(); + }, + { serialize: false }, + ); + + this.keypressBound = this.onKeypress.bind(this); + this.keyupBound = this.onKeyup.bind(this); + this.onConstructed(); + } + + /** When we're given data to configure, like from a PNG or JSON. */ + override configure(info: SerializedLGraphNode): void { + super.configure(info); + // Since we add the widgets dynamically, we need to wait to set their values + // with a short timeout. + setTimeout(() => { + if (info.widgets_values) { + for (let [index, value] of info.widgets_values.entries()) { + if (index > 0) { + if (value.startsWith("comfy_action:")) { + value = value.replace("comfy_action:", ""); + this.addComfyActionWidget(index, value); + } + if (this.widgets[index]) { + this.widgets[index]!.value = value; + } + } + } + } + }, 100); + } + + override clone() { + const cloned = super.clone(); + cloned.properties["buttonText"] = "🎬 Action!"; + cloned.properties["shortcutKey"] = ""; + return cloned; + } + + override onAdded(graph: LGraph): void { + window.addEventListener("keydown", this.keypressBound); + window.addEventListener("keyup", this.keyupBound); + } + + override onRemoved(): void { + window.removeEventListener("keydown", this.keypressBound); + window.removeEventListener("keyup", this.keyupBound); + } + + async onKeypress(event: KeyboardEvent) { + const target = (event.target as HTMLElement)!; + if ( + this.executingFromShortcut || + target.localName == "input" || + target.localName == "textarea" + ) { + return; + } + if ( + this.properties["shortcutKey"].trim() && + this.properties["shortcutKey"].toLowerCase() === event.key.toLowerCase() + ) { + const shortcutModifier = this.properties["shortcutModifier"]; + let good = shortcutModifier === "ctrl" && event.ctrlKey; + good = good || (shortcutModifier === "alt" && event.altKey); + good = good || (shortcutModifier === "shift" && event.shiftKey); + good = good || (shortcutModifier === "meta" && event.metaKey); + if (good) { + setTimeout(() => { + this.executeConnectedNodes(); + }, 20); + this.executingFromShortcut = true; + event.preventDefault(); + event.stopImmediatePropagation(); + app.canvas.dirty_canvas = true; + return false; + } + } + return; + } + + onKeyup(event: KeyboardEvent) { + const target = (event.target as HTMLElement)!; + if (target.localName == "input" || target.localName == "textarea") { + return; + } + this.executingFromShortcut = false; + } + + override 
onPropertyChanged(property: string, value: any, _prevValue: any): boolean | void { + if (property == "buttonText") { + this.buttonWidget.name = value; + } + if (property == "shortcutKey") { + value = value.trim(); + this.properties["shortcutKey"] = (value && value[0].toLowerCase()) || ""; + } + } + + override handleLinkedNodesStabilization(linkedNodes: LGraphNode[]) { + // Remove any widgets and data for widgets that are no longer linked. + for (const [widget, data] of this.widgetToData.entries()) { + if (!data.node) { + continue; + } + if (!linkedNodes.includes(data.node)) { + const index = this.widgets.indexOf(widget); + if (index > -1) { + this.widgetToData.delete(widget); + this.removeWidget(widget); + } else { + const [m, a] = this.logger.debugParts("Connected widget is not in widgets... weird."); + console[m]?.(...a); + } + } + } + + const badNodes: LGraphNode[] = []; // Nodes that are deleted elsewhere may not exist in linkedNodes. + let indexOffset = 1; // Start with button, increment when we hit a non-node widget (like comfy) + for (const [index, node] of linkedNodes.entries()) { + // Sometimes linkedNodes is stale. + if (!node) { + const [m, a] = this.logger.debugParts("linkedNode provided that does not exist. "); + console[m]?.(...a); + badNodes.push(node); + continue; + } + let widgetAtSlot = this.widgets[index + indexOffset]; + if (widgetAtSlot && this.widgetToData.get(widgetAtSlot)?.comfy) { + indexOffset++; + widgetAtSlot = this.widgets[index + indexOffset]; + } + + if (!widgetAtSlot || this.widgetToData.get(widgetAtSlot)?.node?.id !== node.id) { + // Find the next widget that matches the node. + let widget: IWidget | null = null; + for (let i = index + indexOffset; i < this.widgets.length; i++) { + if (this.widgetToData.get(this.widgets[i]!)?.node?.id === node.id) { + widget = this.widgets.splice(i, 1)[0]!; + this.widgets.splice(index + indexOffset, 0, widget); + break; + } + } + if (!widget) { + // Add a widget at this spot. + const exposedActions: string[] = (node.constructor as any).exposedActions || []; + widget = this.addWidget("combo", node.title, "None", "", { + values: ["None", "Mute", "Bypass", "Enable", ...exposedActions], + }); + (widget as ComfyWidget).serializeValue = async (_node: LGraphNode, _index: number) => { + return widget?.value; + }; + this.widgetToData.set(widget, { node }); + } + } + } + + // Go backwards through widgets, and remove any that are not in out widgetToData + for (let i = this.widgets.length - 1; i > linkedNodes.length + indexOffset - 1; i--) { + const widgetAtSlot = this.widgets[i]; + if (widgetAtSlot && this.widgetToData.get(widgetAtSlot)?.comfy) { + continue; + } + this.removeWidget(widgetAtSlot); + } + } + + override removeWidget(widgetOrSlot?: number | IWidget): void { + const widget = typeof widgetOrSlot === "number" ? this.widgets[widgetOrSlot] : widgetOrSlot; + if (widget && this.widgetToData.has(widget)) { + this.widgetToData.delete(widget); + } + super.removeWidget(widgetOrSlot); + } + + /** + * Runs through the widgets, and executes the actions. + */ + async executeConnectedNodes() { + for (const widget of this.widgets) { + if (widget == this.buttonWidget) { + continue; + } + const action = widget.value; + const { comfy, node } = this.widgetToData.get(widget) ?? 
{}; + if (comfy) { + if (action === "Queue Prompt") { + await comfy.queuePrompt(0); + } + continue; + } + if (node) { + if (action === "Mute") { + node.mode = MODE_MUTE; + } else if (action === "Bypass") { + node.mode = MODE_BYPASS; + } else if (action === "Enable") { + node.mode = MODE_ALWAYS; + } + // If there's a handleAction, always call it. + if ((node as RgthreeBaseNode).handleAction) { + await (node as RgthreeBaseNode).handleAction(action); + } + app.graph.change(); + continue; + } + console.warn("Fast Actions Button has a widget without correct data."); + } + } + + /** + * Adds a ComfyActionWidget at the provided slot (or end). + */ + addComfyActionWidget(slot?: number, value?: string) { + let widget = this.addWidget( + "combo", + "Comfy Action", + "None", + () => { + if (widget.value.startsWith("MOVE ")) { + this.widgets.push(this.widgets.splice(this.widgets.indexOf(widget), 1)[0]!); + widget.value = (widget as any)["lastValue_"]; + } else if (widget.value.startsWith("REMOVE ")) { + this.removeWidget(widget); + } + (widget as any)["lastValue_"] = widget.value; + }, + { + values: ["None", "Queue Prompt", "REMOVE Comfy Action", "MOVE to end"], + }, + ); + (widget as any)["lastValue_"] = value; + + (widget as ComfyWidget).serializeValue = async (_node: LGraphNode, _index: number) => { + return `comfy_app:${widget?.value}`; + }; + this.widgetToData.set(widget, { comfy: app }); + + if (slot != null) { + this.widgets.splice(slot, 0, this.widgets.splice(this.widgets.indexOf(widget), 1)[0]!); + } + return widget; + } + + override onSerialize(o: SerializedLGraphNode) { + super.onSerialize && super.onSerialize(o); + for (let [index, value] of (o.widgets_values || []).entries()) { + if (this.widgets[index]?.name === "Comfy Action") { + o.widgets_values![index] = `comfy_action:${value}`; + } + } + } + + static override setUp() { + super.setUp(); + addMenuItem(this, app, { + name: "➕ Append a Comfy Action", + callback: (nodeArg: LGraphNode) => { + (nodeArg as FastActionsButton).addComfyActionWidget(); + }, + }); + } +} + +app.registerExtension({ + name: "rgthree.FastActionsButton", + registerCustomNodes() { + FastActionsButton.setUp(); + }, + loadedGraphNode(node: LGraphNode) { + if (node.type == FastActionsButton.title) { + (node as FastActionsButton)._tempWidth = node.size[0]; + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/fast_groups_bypasser.ts b/rgthree-comfy/src_web/comfyui/fast_groups_bypasser.ts new file mode 100644 index 0000000000000000000000000000000000000000..64c5fa4bc6ae469d14f46903d9aadd2c33e2b31c --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/fast_groups_bypasser.ts @@ -0,0 +1,37 @@ +import { RgthreeBaseVirtualNodeConstructor } from "typings/rgthree.js"; +import { app } from "scripts/app.js"; +import { NodeTypesString } from "./constants.js"; +import { BaseFastGroupsModeChanger } from "./fast_groups_muter.js"; + +/** + * Fast Bypasser implementation that looks for groups in the workflow and adds toggles to mute them. 
+ */ +export class FastGroupsBypasser extends BaseFastGroupsModeChanger { + static override type = NodeTypesString.FAST_GROUPS_BYPASSER; + static override title = NodeTypesString.FAST_GROUPS_BYPASSER; + override comfyClass = NodeTypesString.FAST_GROUPS_BYPASSER; + + static override exposedActions = ["Bypass all", "Enable all", "Toggle all"]; + + protected override helpActions = "bypass and enable"; + + override readonly modeOn = LiteGraph.ALWAYS; + override readonly modeOff = 4; // Used by Comfy for "bypass" + + constructor(title = FastGroupsBypasser.title) { + super(title); + this.onConstructed(); + } +} + +app.registerExtension({ + name: "rgthree.FastGroupsBypasser", + registerCustomNodes() { + FastGroupsBypasser.setUp(); + }, + loadedGraphNode(node: FastGroupsBypasser) { + if (node.type == FastGroupsBypasser.title) { + node.tempSize = [...node.size]; + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/fast_groups_muter.ts b/rgthree-comfy/src_web/comfyui/fast_groups_muter.ts new file mode 100644 index 0000000000000000000000000000000000000000..fa056948c3a0e4cfd84772507272f94e06f17487 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/fast_groups_muter.ts @@ -0,0 +1,502 @@ +import { app } from "scripts/app.js"; +import { RgthreeBaseVirtualNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; +import { + type LGraphNode, + type LGraph as TLGraph, + LGraphCanvas as TLGraphCanvas, + Vector2, + SerializedLGraphNode, + IWidget, +} from "typings/litegraph.js"; +import { SERVICE as FAST_GROUPS_SERVICE } from "./services/fast_groups_service.js"; +import { drawNodeWidget, fitString } from "./utils_canvas.js"; +import { RgthreeBaseVirtualNodeConstructor } from "typings/rgthree.js"; + +const PROPERTY_SORT = "sort"; +const PROPERTY_SORT_CUSTOM_ALPHA = "customSortAlphabet"; +const PROPERTY_MATCH_COLORS = "matchColors"; +const PROPERTY_MATCH_TITLE = "matchTitle"; +const PROPERTY_SHOW_NAV = "showNav"; +const PROPERTY_RESTRICTION = "toggleRestriction"; + +/** + * Fast Muter implementation that looks for groups in the workflow and adds toggles to mute them. 
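+ *
+ * For instance (property values here are illustrative only), setting the `matchTitle` property to
+ * "^(upscale|detail)" limits the toggles to groups whose titles match that regex, and setting
+ * `matchColors` to "blue,#224433" limits them to groups using those group colors.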
+ */ +export abstract class BaseFastGroupsModeChanger extends RgthreeBaseVirtualNode { + static override type = NodeTypesString.FAST_GROUPS_MUTER; + static override title = NodeTypesString.FAST_GROUPS_MUTER; + + static override exposedActions = ["Mute all", "Enable all", "Toggle all"]; + + readonly modeOn: number = LiteGraph.ALWAYS; + readonly modeOff: number = LiteGraph.NEVER; + + private debouncerTempWidth: number = 0; + tempSize: Vector2 | null = null; + + // We don't need to serizalize since we'll just be checking group data on startup anyway + override serialize_widgets = false; + + protected helpActions = "mute and unmute"; + + static "@matchColors" = { type: "string" }; + static "@matchTitle" = { type: "string" }; + static "@showNav" = { type: "boolean" }; + static "@sort" = { + type: "combo", + values: ["position", "alphanumeric", "custom alphabet"], + }; + static "@customSortAlphabet" = { type: "string" }; + + static "@toggleRestriction" = { + type: "combo", + values: ["default", "max one", "always one"], + }; + + constructor(title = FastGroupsMuter.title) { + super(title); + this.properties[PROPERTY_MATCH_COLORS] = ""; + this.properties[PROPERTY_MATCH_TITLE] = ""; + this.properties[PROPERTY_SHOW_NAV] = true; + this.properties[PROPERTY_SORT] = "position"; + this.properties[PROPERTY_SORT_CUSTOM_ALPHA] = ""; + this.properties[PROPERTY_RESTRICTION] = "default"; + } + + override onConstructed(): boolean { + this.addOutput("OPT_CONNECTION", "*"); + return super.onConstructed(); + } + + override configure(info: SerializedLGraphNode): void { + // Patch a small issue (~14h) where multiple OPT_CONNECTIONS may have been created. + // https://github.com/rgthree/rgthree-comfy/issues/206 + // TODO: This can probably be removed within a few weeks. + if (info.outputs?.length) { + info.outputs.length = 1; + } + super.configure(info); + } + + override onAdded(graph: TLGraph): void { + FAST_GROUPS_SERVICE.addFastGroupNode(this); + } + + override onRemoved(): void { + FAST_GROUPS_SERVICE.removeFastGroupNode(this); + } + + refreshWidgets() { + const canvas = app.canvas as TLGraphCanvas; + let sort = this.properties?.[PROPERTY_SORT] || "position"; + let customAlphabet: string[] | null = null; + if (sort === "custom alphabet") { + const customAlphaStr = this.properties?.[PROPERTY_SORT_CUSTOM_ALPHA]?.replace(/\n/g, ""); + if (customAlphaStr && customAlphaStr.trim()) { + customAlphabet = customAlphaStr.includes(",") + ? customAlphaStr.toLocaleLowerCase().split(",") + : customAlphaStr.toLocaleLowerCase().trim().split(""); + } + if (!customAlphabet?.length) { + sort = "alphanumeric"; + customAlphabet = null; + } + } + + const groups = [...FAST_GROUPS_SERVICE.getGroups(sort)]; + // The service will return pre-sorted groups for alphanumeric and position. If this node has a + // custom sort, then we need to sort it manually. + if (customAlphabet?.length) { + groups.sort((a, b) => { + let aIndex = -1; + let bIndex = -1; + // Loop and find indexes. As we're finding multiple, a single for loop is more efficient. + for (const [index, alpha] of customAlphabet!.entries()) { + aIndex = + aIndex < 0 ? (a.title.toLocaleLowerCase().startsWith(alpha) ? index : -1) : aIndex; + bIndex = + bIndex < 0 ? (b.title.toLocaleLowerCase().startsWith(alpha) ? index : -1) : bIndex; + if (aIndex > -1 && bIndex > -1) { + break; + } + } + // Now compare. 
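+ // Worked example (titles and alphabet are illustrative): with a custom alphabet of
+ // ["q", "w", "e"], a group titled "Wildcards" resolves to index 1 and "Effects" to index 2, so
+ // "Wildcards" sorts first; groups matching no entry fall through to a plain localeCompare.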
+ if (aIndex > -1 && bIndex > -1) { + const ret = aIndex - bIndex; + if (ret === 0) { + return a.title.localeCompare(b.title); + } + return ret; + } else if (aIndex > -1) { + return -1; + } else if (bIndex > -1) { + return 1; + } + return a.title.localeCompare(b.title); + }); + } + + // See if we're filtering by colors, and match against the built-in keywords and actuial hex + // values. + let filterColors = ( + (this.properties?.[PROPERTY_MATCH_COLORS] as string)?.split(",") || [] + ).filter((c) => c.trim()); + if (filterColors.length) { + filterColors = filterColors.map((color) => { + color = color.trim().toLocaleLowerCase(); + if (LGraphCanvas.node_colors[color]) { + color = LGraphCanvas.node_colors[color]!.groupcolor; + } + color = color.replace("#", "").toLocaleLowerCase(); + if (color.length === 3) { + color = color.replace(/(.)(.)(.)/, "$1$1$2$2$3$3"); + } + return `#${color}`; + }); + } + + // Go over the groups + let index = 0; + for (const group of groups) { + if (filterColors.length) { + let groupColor = group.color?.replace("#", "").trim().toLocaleLowerCase(); + if (!groupColor) { + continue; + } + if (groupColor.length === 3) { + groupColor = groupColor.replace(/(.)(.)(.)/, "$1$1$2$2$3$3"); + } + groupColor = `#${groupColor}`; + if (!filterColors.includes(groupColor)) { + continue; + } + } + if (this.properties?.[PROPERTY_MATCH_TITLE]?.trim()) { + try { + if (!new RegExp(this.properties[PROPERTY_MATCH_TITLE], "i").exec(group.title)) { + continue; + } + } catch (e) { + console.error(e); + continue; + } + } + const widgetName = `Enable ${group.title}`; + let widget = this.widgets.find((w) => w.name === widgetName); + if (!widget) { + // When we add a widget, litegraph is going to mess up the size, so we + // store it so we can retrieve it in computeSize. Hacky.. + this.tempSize = [...this.size]; + widget = this.addCustomWidget>({ + name: "RGTHREE_TOGGLE_AND_NAV", + label: "", + value: false, + disabled: false, + options: { on: "yes", off: "no" }, + draw: function ( + ctx: CanvasRenderingContext2D, + node: LGraphNode, + width: number, + posY: number, + height: number, + ) { + const widgetData = drawNodeWidget(ctx, { + width, + height, + posY, + }); + + const showNav = node.properties?.[PROPERTY_SHOW_NAV] !== false; + + // Render from right to left, since the text on left will take available space. + // `currentX` markes the current x position moving backwards. + let currentX = widgetData.width - widgetData.margin; + + // The nav arrow + if (!widgetData.lowQuality && showNav) { + currentX -= 7; // Arrow space margin + const midY = widgetData.posY + widgetData.height * 0.5; + ctx.fillStyle = ctx.strokeStyle = "#89A"; + ctx.lineJoin = "round"; + ctx.lineCap = "round"; + const arrow = new Path2D(`M${currentX} ${midY} l -7 6 v -3 h -7 v -6 h 7 v -3 z`); + ctx.fill(arrow); + ctx.stroke(arrow); + currentX -= 14; + + currentX -= 7; + ctx.strokeStyle = widgetData.colorOutline; + ctx.stroke(new Path2D(`M ${currentX} ${widgetData.posY} v ${widgetData.height}`)); + } else if (widgetData.lowQuality && showNav) { + currentX -= 28; + } + + // The toggle itself. + currentX -= 7; + ctx.fillStyle = this.value ? "#89A" : "#333"; + ctx.beginPath(); + const toggleRadius = height * 0.36; + ctx.arc(currentX - toggleRadius, posY + height * 0.5, toggleRadius, 0, Math.PI * 2); + ctx.fill(); + currentX -= toggleRadius * 2; + + if (!widgetData.lowQuality) { + currentX -= 4; + ctx.textAlign = "right"; + ctx.fillStyle = this.value ? 
widgetData.colorText : widgetData.colorTextSecondary; + const label = this.label || this.name; + const toggleLabelOn = this.options.on || "true"; + const toggleLabelOff = this.options.off || "false"; + ctx.fillText( + this.value ? toggleLabelOn : toggleLabelOff, + currentX, + posY + height * 0.7, + ); + currentX -= Math.max( + ctx.measureText(toggleLabelOn).width, + ctx.measureText(toggleLabelOff).width, + ); + + currentX -= 7; + ctx.textAlign = "left"; + let maxLabelWidth = + widgetData.width - widgetData.margin - 10 - (widgetData.width - currentX); + if (label != null) { + ctx.fillText( + fitString(ctx, label, maxLabelWidth), + widgetData.margin + 10, + posY + height * 0.7, + ); + } + } + }, + serializeValue(serializedNode: SerializedLGraphNode, widgetIndex: number) { + return this.value; + }, + mouse(event: PointerEvent, pos: Vector2, node: LGraphNode) { + if (event.type == "pointerdown") { + if ( + node.properties?.[PROPERTY_SHOW_NAV] !== false && + pos[0] >= node.size[0] - 15 - 28 - 1 + ) { + const canvas = app.canvas as TLGraphCanvas; + const lowQuality = (canvas.ds?.scale || 1) <= 0.5; + if (!lowQuality) { + // Clicked on right half with nav arrow, go to the group, center on group and set + // zoom to see it all. + canvas.centerOnNode(group); + const zoomCurrent = canvas.ds?.scale || 1; + const zoomX = canvas.canvas.width / group._size[0] - 0.02; + const zoomY = canvas.canvas.height / group._size[1] - 0.02; + canvas.setZoom(Math.min(zoomCurrent, zoomX, zoomY), [ + canvas.canvas.width / 2, + canvas.canvas.height / 2, + ]); + canvas.setDirty(true, true); + } + } else { + this.value = !this.value; + setTimeout(() => { + this.callback?.(this.value, app.canvas, node, pos, event); + }, 20); + } + } + return true; + }, + }); + (widget as any).doModeChange = (force?: boolean, skipOtherNodeCheck?: boolean) => { + group.recomputeInsideNodes(); + const hasAnyActiveNodes = group._nodes.some((n) => n.mode === LiteGraph.ALWAYS); + let newValue = force != null ? force : !hasAnyActiveNodes; + if (skipOtherNodeCheck !== true) { + if (newValue && this.properties?.[PROPERTY_RESTRICTION]?.includes(" one")) { + for (const widget of this.widgets) { + (widget as any).doModeChange(false, true); + } + } else if (!newValue && this.properties?.[PROPERTY_RESTRICTION] === "always one") { + newValue = this.widgets.every((w) => !w.value || w === widget); + } + } + for (const node of group._nodes) { + node.mode = (newValue ? this.modeOn : this.modeOff) as 1 | 2 | 3 | 4; + } + (group as any)._rgthreeHasAnyActiveNode = newValue; + widget!.value = newValue; + app.graph.setDirtyCanvas(true, false); + }; + widget.callback = () => { + (widget as any).doModeChange(); + }; + + this.setSize(this.computeSize()); + } + if (widget.name != widgetName) { + widget.name = widgetName; + this.setDirtyCanvas(true, false); + } + if (widget.value != (group as any)._rgthreeHasAnyActiveNode) { + widget.value = (group as any)._rgthreeHasAnyActiveNode; + this.setDirtyCanvas(true, false); + } + if (this.widgets[index] !== widget) { + const oldIndex = this.widgets.findIndex((w) => w === widget); + this.widgets.splice(index, 0, this.widgets.splice(oldIndex, 1)[0]!); + this.setDirtyCanvas(true, false); + } + index++; + } + + // Everything should now be in order, so let's remove all remaining widgets. 
+ while ((this.widgets || [])[index]) { + this.removeWidget(index++); + } + } + + override computeSize(out?: Vector2) { + let size = super.computeSize(out); + if (this.tempSize) { + size[0] = Math.max(this.tempSize[0], size[0]); + size[1] = Math.max(this.tempSize[1], size[1]); + // We sometimes get repeated calls to compute size, so debounce before clearing. + this.debouncerTempWidth && clearTimeout(this.debouncerTempWidth); + this.debouncerTempWidth = setTimeout(() => { + this.tempSize = null; + }, 32); + } + setTimeout(() => { + app.graph.setDirtyCanvas(true, true); + }, 16); + return size; + } + + override async handleAction(action: string) { + if (action === "Mute all" || action === "Bypass all") { + const alwaysOne = this.properties?.[PROPERTY_RESTRICTION] === "always one"; + for (const [index, widget] of this.widgets.entries()) { + (widget as any)?.doModeChange(alwaysOne && !index ? true : false, true); + } + } else if (action === "Enable all") { + const onlyOne = this.properties?.[PROPERTY_RESTRICTION].includes(" one"); + for (const [index, widget] of this.widgets.entries()) { + (widget as any)?.doModeChange(onlyOne && index > 0 ? false : true, true); + } + } else if (action === "Toggle all") { + const onlyOne = this.properties?.[PROPERTY_RESTRICTION].includes(" one"); + let foundOne = false; + for (const [index, widget] of this.widgets.entries()) { + // If you have only one, then we'll stop at the first. + let newValue: boolean = onlyOne && foundOne ? false : !widget.value; + foundOne = foundOne || newValue; + (widget as any)?.doModeChange(newValue, true); + } + // And if you have always one, then we'll flip the last + if (!foundOne && this.properties?.[PROPERTY_RESTRICTION] === "always one") { + (this.widgets[this.widgets.length - 1] as any)?.doModeChange(true, true); + } + } + } + + override getHelp() { + return ` +

+      <p>
+        The ${this.type!.replace(
+          "(rgthree)",
+          "",
+        )} is an input-less node that automatically collects all groups in your current
+        workflow and allows you to quickly ${this.helpActions} all nodes within the group.
+      </p>
+      <ul>
+        <li>
+          <p>
+            <b>Properties.</b> You can change the following properties (by right-clicking
+            on the node and selecting "Properties" or "Properties Panel" from the menu):
+          </p>
+          <ul>
+            <li><p>
+              <code>${PROPERTY_MATCH_COLORS}</code> - Only add groups that match the provided
+              colors. Can be ComfyUI colors (red, pale_blue) or hex codes (#a4d399). Multiple can
+              be added, comma delimited.
+            </p></li>
+            <li><p>
+              <code>${PROPERTY_MATCH_TITLE}</code> - Filter the list of toggles by title match
+              (string match, or regular expression).
+            </p></li>
+            <li><p>
+              <code>${PROPERTY_SHOW_NAV}</code> - Add / remove a quick navigation arrow to take
+              you to the group. (default: true)
+            </p></li>
+            <li><p>
+              <code>${PROPERTY_SORT}</code> - Sort the toggles' order by "alphanumeric", graph
+              "position", or "custom alphabet". (default: "position")
+            </p></li>
+            <li>
+              <p>
+                <code>${PROPERTY_SORT_CUSTOM_ALPHA}</code> - When the
+                <code>${PROPERTY_SORT}</code> property is "custom alphabet" you can define the
+                alphabet to use here, which will match the beginning of each group name and
+                sort against it. If group titles do not match any custom alphabet entry, then they
+                will be put after groups that do, ordered alphanumerically.
+              </p>
+              <p>
+                This can be a list of single characters, like "zyxw..." or comma delimited strings
+                for more control, like "sdxl,pro,sd,n,p".
+              </p>
+              <p>
+                Note, when two group titles match the same custom alphabet entry, the normal
+                alphanumeric alphabet breaks the tie. For instance, a custom alphabet of
+                "e,s,d" will order group names like "SDXL, SEGS, Detailer" even though the custom
+                alphabet has an "e" before "d" (where one may expect "SE" to be before "SD").
+              </p>
+              <p>
+                To have "SEGS" appear before "SDXL" you can use longer strings. For instance, the
+                custom alphabet value of "se,s,f" would work here.
+              </p>
+            </li>
+            <li>
+              <p>
+                <code>${PROPERTY_RESTRICTION}</code> - Optionally, attempt to restrict the number
+                of widgets that can be enabled to a maximum of one, or always one.
+              </p>
+              <p>
+                <i>Note: If using "max one" or "always one" then this is only enforced when
+                clicking a toggle on this node; if nodes within groups are changed outside of the
+                initial toggle click, then these restrictions will not be enforced and could
+                result in a state where more than one toggle is enabled. This could also happen
+                if nodes overlap multiple groups.</i>
+              </p>
+            </li>
+          </ul>
+        </li>
+      </ul>
`; + } +} + +/** + * Fast Bypasser implementation that looks for groups in the workflow and adds toggles to mute them. + */ +export class FastGroupsMuter extends BaseFastGroupsModeChanger { + static override type = NodeTypesString.FAST_GROUPS_MUTER; + static override title = NodeTypesString.FAST_GROUPS_MUTER; + override comfyClass = NodeTypesString.FAST_GROUPS_MUTER; + + static override exposedActions = ["Bypass all", "Enable all", "Toggle all"]; + + protected override helpActions = "mute and unmute"; + + override readonly modeOn: number = LiteGraph.ALWAYS; + override readonly modeOff: number = LiteGraph.NEVER; + + constructor(title = FastGroupsMuter.title) { + super(title); + this.onConstructed(); + } +} + +app.registerExtension({ + name: "rgthree.FastGroupsMuter", + registerCustomNodes() { + FastGroupsMuter.setUp(); + }, + loadedGraphNode(node: LGraphNode) { + if (node.type == FastGroupsMuter.title) { + (node as FastGroupsMuter).tempSize = [...node.size]; + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/feature_group_fast_toggle.ts b/rgthree-comfy/src_web/comfyui/feature_group_fast_toggle.ts new file mode 100644 index 0000000000000000000000000000000000000000..72a6bb72b4f1cf459ab22dd74126364424335079 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/feature_group_fast_toggle.ts @@ -0,0 +1,248 @@ +import type { + LGraphCanvas as TLGraphCanvas, + LGraphGroup as TLGraphGroup, + LGraph as TLGraph, + AdjustedMouseEvent, + Vector2, +} from "typings/litegraph.js"; +import type { AdjustedMouseCustomEvent } from "typings/rgthree.js"; + +import { app } from "scripts/app.js"; +import { rgthree } from "./rgthree.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; + +const BTN_SIZE = 20; +const BTN_MARGIN: Vector2 = [6, 4]; +const BTN_SPACING = 8; +const BTN_GRID = BTN_SIZE / 8; + +const TOGGLE_TO_MODE = new Map([ + ["MUTE", LiteGraph.NEVER], + ["BYPASS", 4], +]); + +/** + * Determines if the user clicked on an fast header icon. + */ +function clickedOnToggleButton(e: AdjustedMouseEvent, group: TLGraphGroup): string | null { + const toggles = CONFIG_SERVICE.getFeatureValue("group_header_fast_toggle.toggles"); + const pos = group.pos; + const size = group.size; + for (let i = 0; i < toggles.length; i++) { + const toggle = toggles[i]; + if ( + LiteGraph.isInsideRectangle( + e.canvasX, + e.canvasY, + pos[0] + size[0] - (BTN_SIZE + BTN_MARGIN[0]) * (i + 1), + pos[1] + BTN_MARGIN[1], + BTN_SIZE, + BTN_SIZE, + ) + ) { + return toggle; + } + } + return null; +} + +/** + * Registers the GroupHeaderToggles which places a mute and/or bypass icons in groups headers for + * quick, single-click ability to mute/bypass. + */ +app.registerExtension({ + name: "rgthree.GroupHeaderToggles", + async setup() { + /** + * Handles a click on the icon area if the user has the extension enable from settings. + * Hooks into the already overriden mouse down processor from rgthree. 
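+ *
+ * A simplified sketch of the mode flip the handler below applies once a header toggle is
+ * hit (the numeric bypass mode 4 follows TOGGLE_TO_MODE above; this helper and its `group`
+ * parameter are illustrative only, not part of the extension):
+ *
+ *   function nextMode(toggle: "MUTE" | "BYPASS", group: TLGraphGroup): number {
+ *     // If every node already has the toggled mode, flip the whole group back to active.
+ *     const allMuted = group._nodes.every((n) => n.mode === LiteGraph.NEVER);
+ *     const allBypassed = group._nodes.every((n) => n.mode === 4);
+ *     return toggle === "MUTE"
+ *       ? (allMuted ? LiteGraph.ALWAYS : LiteGraph.NEVER)
+ *       : (allBypassed ? LiteGraph.ALWAYS : 4);
+ *   }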
+ */ + rgthree.addEventListener("on-process-mouse-down", ((e: AdjustedMouseCustomEvent) => { + if (!CONFIG_SERVICE.getFeatureValue("group_header_fast_toggle.enabled")) return; + + const canvas = app.canvas as TLGraphCanvas; + if (canvas.selected_group) { + const originalEvent = e.detail.originalEvent; + const group = canvas.selected_group; + const clickedOnToggle = clickedOnToggleButton(originalEvent, group) || ""; + const toggleMode = TOGGLE_TO_MODE.get(clickedOnToggle?.toLocaleUpperCase()); + if (toggleMode) { + group.recomputeInsideNodes(); + const hasAnyActiveNodes = group._nodes.some((n) => n.mode === LiteGraph.ALWAYS); + const isAllMuted = + !hasAnyActiveNodes && group._nodes.every((n) => n.mode === LiteGraph.NEVER); + const isAllBypassed = + !hasAnyActiveNodes && !isAllMuted && group._nodes.every((n) => n.mode === 4); + + let newMode: 0 | 1 | 2 | 3 | 4 = LiteGraph.ALWAYS; + if (toggleMode === LiteGraph.NEVER) { + newMode = isAllMuted ? LiteGraph.ALWAYS : LiteGraph.NEVER; + } else { + newMode = isAllBypassed ? LiteGraph.ALWAYS : 4; + } + for (const node of group._nodes) { + node.mode = newMode; + } + // Make it such that we're not then moving the group on drag. + canvas.selected_group = null; + canvas.dragging_canvas = false; + } + } + }) as EventListener); + + /** + * Overrides LiteGraph's Canvas method for drawingGroups and, after calling the original, checks + * that the user has enabled fast toggles and draws them on the top-right of the app.. + */ + const drawGroups = LGraphCanvas.prototype.drawGroups; + LGraphCanvas.prototype.drawGroups = function ( + canvasEl: HTMLCanvasElement, + ctx: CanvasRenderingContext2D, + ) { + drawGroups.apply(this, [...arguments] as any); + + if ( + !CONFIG_SERVICE.getFeatureValue("group_header_fast_toggle.enabled") || + !rgthree.lastAdjustedMouseEvent + ) { + return; + } + + const graph = app.graph as TLGraph; + + let groups: TLGraphGroup[]; + // Default to hover if not always. + if (CONFIG_SERVICE.getFeatureValue("group_header_fast_toggle.show") !== "always") { + const hoverGroup = graph.getGroupOnPos( + rgthree.lastAdjustedMouseEvent.canvasX, + rgthree.lastAdjustedMouseEvent.canvasY, + ); + groups = hoverGroup ? [hoverGroup] : []; + } else { + groups = graph._groups || []; + } + + if (!groups.length) { + return; + } + + const toggles = CONFIG_SERVICE.getFeatureValue("group_header_fast_toggle.toggles"); + + ctx.save(); + for (const group of groups || []) { + let anyActive = false; + let allMuted = !!group._nodes.length; + let allBypassed = allMuted; + + // Find the current state of the group's nodes. + for (const node of group._nodes) { + anyActive = anyActive || node.mode === LiteGraph.ALWAYS; + allMuted = allMuted && node.mode === LiteGraph.NEVER; + allBypassed = allBypassed && node.mode === 4; + if (anyActive || (!allMuted && !allBypassed)) { + break; + } + } + + // Display each toggle. + for (let i = 0; i < toggles.length; i++) { + const toggle = toggles[i]; + const on = toggle === "bypass" ? 
allBypassed : allMuted; + const pos = group._pos; + const size = group._size; + + ctx.fillStyle = ctx.strokeStyle = group.color || "#335"; + const x = pos[0] + size[0] - BTN_MARGIN[0] - BTN_SIZE - (BTN_SPACING + BTN_SIZE) * i; + const y = pos[1] + BTN_MARGIN[1]; + const midX = x + BTN_SIZE / 2; + const midY = y + BTN_SIZE / 2; + ctx.beginPath(); + ctx.lineJoin = "round"; + ctx.rect(x, y, BTN_SIZE, BTN_SIZE); + + ctx.lineWidth = 2; + if (toggle === "mute") { + ctx.lineJoin = "round"; + ctx.lineCap = "round"; + + if (on) { + ctx.stroke( + new Path2D(` + ${eyeFrame(midX, midY)} + ${eyeLashes(midX, midY)} + `), + ); + } else { + const radius = BTN_GRID * 1.5; + + // Eyeball fill + ctx.fill( + new Path2D(` + ${eyeFrame(midX, midY)} + ${eyeFrame(midX, midY, -1)} + ${circlePath(midX, midY, radius)} + ${circlePath(midX + BTN_GRID / 2, midY - BTN_GRID / 2, BTN_GRID * 0.375)} + `), + "evenodd", + ); + + // Eye Outline Stroke + ctx.stroke(new Path2D(`${eyeFrame(midX, midY)} ${eyeFrame(midX, midY, -1)}`)); + + // Eye lashes (faded) + ctx.globalAlpha = this.editor_alpha * 0.5; + ctx.stroke(new Path2D(`${eyeLashes(midX, midY)} ${eyeLashes(midX, midY, -1)}`)); + ctx.globalAlpha = this.editor_alpha; + } + } else { + const lineChanges = on + ? `a ${BTN_GRID * 3}, ${BTN_GRID * 3} 0 1, 1 ${BTN_GRID * 3 * 2},0 + l ${BTN_GRID * 2.0} 0` + : `l ${BTN_GRID * 8} 0`; + + ctx.stroke( + new Path2D(` + M ${x} ${midY} + ${lineChanges} + M ${x + BTN_SIZE} ${midY} l -2 2 + M ${x + BTN_SIZE} ${midY} l -2 -2 + `), + ); + ctx.fill(new Path2D(`${circlePath(x + BTN_GRID * 3, midY, BTN_GRID * 1.8)}`)); + } + } + } + ctx.restore(); + }; + }, +}); + +function eyeFrame(midX: number, midY: number, yFlip = 1) { + return ` + M ${midX - BTN_SIZE / 2} ${midY} + c ${BTN_GRID * 1.5} ${yFlip * BTN_GRID * 2.5}, ${BTN_GRID * (8 - 1.5)} ${ + yFlip * BTN_GRID * 2.5 + }, ${BTN_GRID * 8} 0 + `; +} + +function eyeLashes(midX: number, midY: number, yFlip = 1) { + return ` + M ${midX - BTN_GRID * 3.46} ${midY + yFlip * BTN_GRID * 0.9} l -1.15 ${1.25 * yFlip} + M ${midX - BTN_GRID * 2.38} ${midY + yFlip * BTN_GRID * 1.6} l -0.90 ${1.5 * yFlip} + M ${midX - BTN_GRID * 1.15} ${midY + yFlip * BTN_GRID * 1.95} l -0.50 ${1.75 * yFlip} + M ${midX + BTN_GRID * 0.0} ${midY + yFlip * BTN_GRID * 2.0} l 0.00 ${2.0 * yFlip} + M ${midX + BTN_GRID * 1.15} ${midY + yFlip * BTN_GRID * 1.95} l 0.50 ${1.75 * yFlip} + M ${midX + BTN_GRID * 2.38} ${midY + yFlip * BTN_GRID * 1.6} l 0.90 ${1.5 * yFlip} + M ${midX + BTN_GRID * 3.46} ${midY + yFlip * BTN_GRID * 0.9} l 1.15 ${1.25 * yFlip} +`; +} + +function circlePath(cx: number, cy: number, radius: number) { + return ` + M ${cx} ${cy} + m ${radius}, 0 + a ${radius},${radius} 0 1, 1 -${radius * 2},0 + a ${radius},${radius} 0 1, 1 ${radius * 2},0 + `; +} diff --git a/rgthree-comfy/src_web/comfyui/feature_import_individual_nodes.ts b/rgthree-comfy/src_web/comfyui/feature_import_individual_nodes.ts new file mode 100644 index 0000000000000000000000000000000000000000..4486bd8ac2092f6e8aeb58e9719b4ec2e1919c55 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/feature_import_individual_nodes.ts @@ -0,0 +1,66 @@ +import { tryToGetWorkflowDataFromEvent } from "rgthree/common/utils_workflow.js"; +import { app } from "scripts/app.js"; +import type { ComfyNode, ComfyNodeConstructor, ComfyObjectInfo } from "typings/comfy.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; + +/** + * Registers the GroupHeaderToggles which places a mute and/or bypass icons in groups headers for + * quick, single-click ability 
to mute/bypass. + */ +app.registerExtension({ + name: "rgthree.ImportIndividualNodes", + async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + const onDragOver = nodeType.prototype.onDragOver; + nodeType.prototype.onDragOver = function (e: DragEvent) { + let handled = onDragOver?.apply?.(this, [...arguments] as any); + if (handled != null) { + return handled; + } + return importIndividualNodesInnerOnDragOver(this, e); + }; + + const onDragDrop = nodeType.prototype.onDragDrop; + nodeType.prototype.onDragDrop = async function (e: DragEvent) { + const alreadyHandled = await onDragDrop?.apply?.(this, [...arguments] as any); + if (alreadyHandled) { + return alreadyHandled; + } + return importIndividualNodesInnerOnDragDrop(this, e); + }; + }, +}); + +export function importIndividualNodesInnerOnDragOver(node: ComfyNode, e: DragEvent): boolean { + return ( + (node.widgets?.length && !!CONFIG_SERVICE.getFeatureValue("import_individual_nodes.enabled")) || + false + ); +} + +export async function importIndividualNodesInnerOnDragDrop(node: ComfyNode, e: DragEvent) { + if (!node.widgets?.length || !CONFIG_SERVICE.getFeatureValue("import_individual_nodes.enabled")) { + return false; + } + + let handled = false; + const { workflow, prompt } = await tryToGetWorkflowDataFromEvent(e); + if (!handled && workflow) { + const exact = (workflow.nodes || []).find((n) => n.id === node.id && n.type === node.type); + if ( + exact && + exact.widgets_values?.length && + confirm( + "Found a node match from embedded workflow (same id & type) in this workflow. Would you like to set the widget values?", + ) + ) { + node.configure({ widgets_values: [...(exact?.widgets_values || [])] } as any); + handled = true; + } + } + if (!handled) { + handled = !confirm( + "No exact match found in workflow. Would you like to replace the whole workflow?", + ); + } + return handled; +} diff --git a/rgthree-comfy/src_web/comfyui/image_comparer.ts b/rgthree-comfy/src_web/comfyui/image_comparer.ts new file mode 100644 index 0000000000000000000000000000000000000000..b5ede313ce278c09ecf9fc4f55aa4c5b7b70f1b5 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/image_comparer.ts @@ -0,0 +1,474 @@ +import { app } from "scripts/app.js"; +import { api } from "scripts/api.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; +import { ComfyNodeConstructor, ComfyObjectInfo } from "typings/comfy.js"; +import { + AdjustedMouseEvent, + LGraphCanvas, + LGraphNode, + SerializedLGraphNode, + Vector2, +} from "typings/litegraph.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { + RgthreeBaseHitAreas, + RgthreeBaseWidget, + RgthreeBaseWidgetBounds, +} from "./utils_widgets.js"; +import { measureText } from "./utils_canvas.js"; + +type ComfyImageServerData = { filename: string; type: string; subfolder: string }; +type ComfyImageData = { name: string; selected: boolean; url: string; img?: HTMLImageElement }; +type OldExecutedPayload = { + images: ComfyImageServerData[]; +}; +type ExecutedPayload = { + a_images?: ComfyImageServerData[]; + b_images?: ComfyImageServerData[]; +}; + +function imageDataToUrl(data: ComfyImageServerData) { + return api.apiURL( + `/view?filename=${encodeURIComponent(data.filename)}&type=${data.type}&subfolder=${ + data.subfolder + }${app.getPreviewFormatParam()}${app.getRandParam()}`, + ); +} + +/** + * Compares two images in one canvas node. 
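+ *
+ * For reference, a sketch of the two `onExecuted` payload shapes this node handles (field
+ * names follow the types declared above; the filenames are made-up placeholders):
+ *
+ *   // Older, single-list payload.
+ *   const legacy: OldExecutedPayload = {
+ *     images: [{ filename: "a.png", type: "output", subfolder: "" }],
+ *   };
+ *   // Current payload with separate A/B image lists.
+ *   const current: ExecutedPayload = {
+ *     a_images: [{ filename: "a.png", type: "output", subfolder: "" }],
+ *     b_images: [{ filename: "b.png", type: "output", subfolder: "" }],
+ *   };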
+ */ +export class RgthreeImageComparer extends RgthreeBaseServerNode { + static override title = NodeTypesString.IMAGE_COMPARER; + static override type = NodeTypesString.IMAGE_COMPARER; + static comfyClass = NodeTypesString.IMAGE_COMPARER; + + // These is what the core preview image node uses to show the context menu. May not be that helpful + // since it likely will always be "0" when a context menu is invoked without manually changing + // something. + imageIndex: number = 0; + imgs: InstanceType[] = []; + + override serialize_widgets = true; + + isPointerDown = false; + isPointerOver = false; + pointerOverPos: Vector2 = [0, 0]; + + private canvasWidget: RgthreeImageComparerWidget | null = null; + + static "@comparer_mode" = { + type: "combo", + values: ["Slide", "Click"], + }; + + constructor(title = RgthreeImageComparer.title) { + super(title); + this.properties["comparer_mode"] = "Slide"; + } + + override onExecuted(output: ExecutedPayload | OldExecutedPayload) { + super.onExecuted?.(output); + if ("images" in output) { + this.canvasWidget!.value = { + images: (output.images || []).map((d, i) => { + return { + name: i === 0 ? "A" : "B", + selected: true, + url: imageDataToUrl(d), + }; + }), + }; + } else { + output.a_images = output.a_images || []; + output.b_images = output.b_images || []; + const imagesToChoose: ComfyImageData[] = []; + const multiple = output.a_images.length + output.b_images.length > 2; + for (const [i, d] of output.a_images.entries()) { + imagesToChoose.push({ + name: output.a_images.length > 1 || multiple ? `A${i + 1}` : "A", + selected: i === 0, + url: imageDataToUrl(d), + }); + } + for (const [i, d] of output.b_images.entries()) { + imagesToChoose.push({ + name: output.b_images.length > 1 || multiple ? `B${i + 1}` : "B", + selected: i === 0, + url: imageDataToUrl(d), + }); + } + this.canvasWidget!.value = { images: imagesToChoose }; + } + } + + override onSerialize(o: SerializedLGraphNode) { + super.onSerialize && super.onSerialize(o); + for (let [index, widget_value] of (o.widgets_values || []).entries()) { + if (this.widgets[index]?.name === "rgthree_comparer") { + o.widgets_values![index] = ( + this.widgets[index] as RgthreeImageComparerWidget + ).value.images.map((d) => { + d = { ...d }; + delete d.img; + return d; + }); + } + } + } + + override onNodeCreated() { + this.canvasWidget = this.addCustomWidget( + new RgthreeImageComparerWidget("rgthree_comparer", this), + ); + this.setSize(this.computeSize()); + this.setDirtyCanvas(true, true); + } + + /** + * Sets mouse as down or up based on param. If it's down, we also loop to check pointer is still + * down. This is because LiteGraph doesn't fire `onMouseUp` every time there's a mouse up, so we + * need to manually monitor `pointer_is_down` and, when it's no longer true, set mouse as up here. + */ + private setIsPointerDown(down: boolean = this.isPointerDown) { + const newIsDown = down && !!app.canvas.pointer_is_down; + if (this.isPointerDown !== newIsDown) { + this.isPointerDown = newIsDown; + this.setDirtyCanvas(true, false); + } + this.imageIndex = this.isPointerDown ? 
1 : 0; + if (this.isPointerDown) { + requestAnimationFrame(() => { + this.setIsPointerDown(); + }); + } + } + + override onMouseDown(event: MouseEvent, pos: Vector2, graphCanvas: LGraphCanvas): void { + super.onMouseDown?.(event, pos, graphCanvas); + this.setIsPointerDown(true); + } + + override onMouseEnter(event: MouseEvent, pos: Vector2, graphCanvas: LGraphCanvas): void { + super.onMouseEnter?.(event, pos, graphCanvas); + this.setIsPointerDown(!!app.canvas.pointer_is_down); + this.isPointerOver = true; + } + + override onMouseLeave(event: MouseEvent, pos: Vector2, graphCanvas: LGraphCanvas): void { + super.onMouseLeave?.(event, pos, graphCanvas); + this.setIsPointerDown(false); + this.isPointerOver = false; + } + + override onMouseMove(event: MouseEvent, pos: Vector2, graphCanvas: LGraphCanvas): void { + super.onMouseMove?.(event, pos, graphCanvas); + this.pointerOverPos = [...pos]; + this.imageIndex = this.pointerOverPos[0] > this.size[0] / 2 ? 1 : 0; + } + + override getHelp(): string { + return ` +

+      <p>
+        The ${this.type!.replace("(rgthree)", "")} node compares two images on top of each other.
+      </p>
+      <ul>
+        <li>
+          <p><b>Notes</b></p>
+          <ul>
+            <li><p>
+              The right-click menu may show image options (Open Image, Save Image, etc.) which
+              will correspond to the first image (image_a) if clicked on the left half of the
+              node, or the second image if on the right half of the node.
+            </p></li>
+          </ul>
+        </li>
+        <li>
+          <p><b>Inputs</b></p>
+          <ul>
+            <li><p>
+              <code>image_a</code> <i>Optional.</i> The first image to use to compare.
+            </p></li>
+            <li><p>
+              <code>image_b</code> <i>Optional.</i> The second image to use to compare.
+            </p></li>
+            <li><p>
+              <b>Note:</b> <code>image_a</code> and <code>image_b</code> work best when a single
+              image is provided. However, if either is a batch, you can choose which item from
+              each batch is compared. If either <code>image_a</code> or <code>image_b</code> is
+              not provided, the node will choose the first two from the provided input if it's a
+              batch, otherwise it will only show the single image (just as Preview Image would).
+            </p></li>
+          </ul>
+        </li>
+        <li>
+          <p>
+            <b>Properties.</b> You can change the following properties (by right-clicking
+            on the node and selecting "Properties" or "Properties Panel" from the menu):
+          </p>
+          <ul>
+            <li><p>
+              <code>comparer_mode</code> - Choose between "Slide" and "Click". Defaults to "Slide".
+            </p></li>
+          </ul>
+        </li>
+      </ul>
`; + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, RgthreeImageComparer); + } + + static override onRegisteredForOverride(comfyClass: any) { + addConnectionLayoutSupport(RgthreeImageComparer, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + setTimeout(() => { + RgthreeImageComparer.category = comfyClass.category; + }); + } +} + +type RgthreeImageComparerWidgetValue = { + images: ComfyImageData[]; +}; + +class RgthreeImageComparerWidget extends RgthreeBaseWidget { + private node: RgthreeImageComparer; + + protected override hitAreas: RgthreeBaseHitAreas = { + // We dynamically set this when/if we draw the labels. + }; + + private selected: [ComfyImageData?, ComfyImageData?] = []; + + constructor(name: string, node: RgthreeImageComparer) { + super(name); + this.node = node; + } + + private _value: RgthreeImageComparerWidgetValue = { images: [] }; + + set value(v: RgthreeImageComparerWidgetValue) { + // Despite `v` typed as RgthreeImageComparerWidgetValue, we may have gotten an array of strings + // from previous versions. We can handle that gracefully. + let cleanedVal; + if (Array.isArray(v)) { + cleanedVal = v.map((d, i) => { + if (!d || typeof d === "string") { + // We usually only have two here, so they're selected. + d = { url: d, name: i == 0 ? "A" : "B", selected: true }; + } + return d; + }); + } else { + cleanedVal = v.images || []; + } + + // If we have multiple items in our sent value but we don't have both an "A" and a "B" then + // just simplify it down to the first two in the list. + if (cleanedVal.length > 2) { + const hasAAndB = + cleanedVal.some((i) => i.name.startsWith("A")) && + cleanedVal.some((i) => i.name.startsWith("B")); + if (!hasAAndB) { + cleanedVal = [cleanedVal[0], cleanedVal[1]]; + } + } + + let selected = cleanedVal.filter((d) => d.selected); + // None are selected. + if (!selected.length && cleanedVal.length) { + cleanedVal[0]!.selected = true; + } + + selected = cleanedVal.filter((d) => d.selected); + if (selected.length === 1 && cleanedVal.length > 1) { + cleanedVal.find((d) => !d.selected)!.selected = true; + } + + this._value.images = cleanedVal; + + selected = cleanedVal.filter((d) => d.selected); + this.setSelected(selected as [ComfyImageData, ComfyImageData]); + } + + get value() { + return this._value; + } + + setSelected(selected: [ComfyImageData, ComfyImageData]) { + this._value.images.forEach((d) => (d.selected = false)); + this.node.imgs.length = 0; + for (const sel of selected) { + if (!sel.img) { + sel.img = new Image(); + sel.img.src = sel.url; + this.node.imgs.push(sel.img); + } + sel.selected = true; + } + this.selected = selected; + } + + draw(ctx: CanvasRenderingContext2D, node: RgthreeImageComparer, width: number, y: number) { + this.hitAreas = {}; + if (this.value.images.length > 2) { + ctx.textAlign = "left"; + ctx.textBaseline = "top"; + ctx.font = `14px Arial`; + // Let's calculate the widths of all the labels. + const drawData: any = []; + const spacing = 5; + let x = 0; + for (const img of this.value.images) { + const width = measureText(ctx, img.name); + drawData.push({ + img, + text: img.name, + x, + width: measureText(ctx, img.name), + }); + x += width + spacing; + } + x = (node.size[0] - (x - spacing)) / 2; + for (const d of drawData) { + ctx.fillStyle = d.img.selected ? 
"rgba(180, 180, 180, 1)" : "rgba(180, 180, 180, 0.5)"; + ctx.fillText(d.text, x, y); + this.hitAreas[d.text] = { + bounds: [x, y, d.width, 14], + data: d.img, + onDown: this.onSelectionDown, + }; + x += d.width + spacing; + } + y += 20; + } + + if (node.properties?.["comparer_mode"] === "Click") { + this.drawImage(ctx, this.selected[this.node.isPointerDown ? 1 : 0], y); + } else { + this.drawImage(ctx, this.selected[0], y); + if (node.isPointerOver) { + this.drawImage(ctx, this.selected[1], y, this.node.pointerOverPos[0]); + } + } + } + + private onSelectionDown( + event: AdjustedMouseEvent, + pos: Vector2, + node: LGraphNode, + bounds?: RgthreeBaseWidgetBounds, + ) { + const selected = [...this.selected]; + if (bounds?.data.name.startsWith("A")) { + selected[0] = bounds.data; + } else if (bounds?.data.name.startsWith("B")) { + selected[1] = bounds.data; + } + this.setSelected(selected as [ComfyImageData, ComfyImageData]); + } + + private drawImage( + ctx: CanvasRenderingContext2D, + image: ComfyImageData | undefined, + y: number, + cropX?: number, + ) { + if (!image?.img?.naturalWidth || !image?.img?.naturalHeight) { + return; + } + let [nodeWidth, nodeHeight] = this.node.size; + const imageAspect = image?.img.naturalWidth / image?.img.naturalHeight; + let height = nodeHeight - y; + const widgetAspect = nodeWidth / height; + let targetWidth, targetHeight; + let offsetX = 0; + if (imageAspect > widgetAspect) { + targetWidth = nodeWidth; + targetHeight = nodeWidth / imageAspect; + } else { + targetHeight = height; + targetWidth = height * imageAspect; + offsetX = (nodeWidth - targetWidth) / 2; + } + const widthMultiplier = image?.img.naturalWidth / targetWidth; + + const sourceX = 0; + const sourceY = 0; + const sourceWidth = + cropX != null ? (cropX - offsetX) * widthMultiplier : image?.img.naturalWidth; + const sourceHeight = image?.img.naturalHeight; + const destX = (nodeWidth - targetWidth) / 2; + const destY = y + (height - targetHeight) / 2; + const destWidth = cropX != null ? cropX - offsetX : targetWidth; + const destHeight = targetHeight; + ctx.save(); + ctx.beginPath(); + let globalCompositeOperation = ctx.globalCompositeOperation; + if (cropX) { + ctx.rect(destX, destY, destWidth, destHeight); + ctx.clip(); + } + ctx.drawImage( + image?.img, + sourceX, + sourceY, + sourceWidth, + sourceHeight, + destX, + destY, + destWidth, + destHeight, + ); + // Shows a label overlayed on the image. Not perfect, keeping commented out. + // ctx.globalCompositeOperation = "difference"; + // ctx.fillStyle = "rgba(180, 180, 180, 1)"; + // ctx.textAlign = "center"; + // ctx.font = `32px Arial`; + // ctx.fillText(image.name, nodeWidth / 2, y + 32); + if (cropX != null && cropX >= (nodeWidth - targetWidth) / 2 && cropX <= targetWidth + offsetX) { + ctx.beginPath(); + ctx.moveTo(cropX, destY); + ctx.lineTo(cropX, destY + destHeight); + ctx.globalCompositeOperation = "difference"; + ctx.strokeStyle = "rgba(255,255,255, 1)"; + ctx.stroke(); + } + ctx.globalCompositeOperation = globalCompositeOperation; + ctx.restore(); + } + + computeSize(width: number): Vector2 { + return [width, 20]; + } + + serializeValue(serializedNode: SerializedLGraphNode, widgetIndex: number) { + const v = []; + for (const data of this._value.images) { + // Remove the img since it can't serialize. 
+ const d = { ...data }; + delete d.img; + v.push(d); + } + return { images: v }; + } +} + +app.registerExtension({ + name: "rgthree.ImageComparer", + async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + if (nodeData.name === RgthreeImageComparer.type) { + RgthreeImageComparer.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/image_inset_crop.ts b/rgthree-comfy/src_web/comfyui/image_inset_crop.ts new file mode 100644 index 0000000000000000000000000000000000000000..9d792b224041cf93b81d269658078ea227a033d9 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/image_inset_crop.ts @@ -0,0 +1,73 @@ +import { app } from "scripts/app.js"; +import type { ComfyApp, ComfyNodeConstructor, ComfyObjectInfo } from "typings/comfy.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { LGraph, LGraphNode, SerializedLGraphNode } from "typings/litegraph.js"; +import { NodeTypesString } from "./constants.js"; + +class ImageInsetCrop extends RgthreeBaseServerNode { + static override title = NodeTypesString.IMAGE_INSET_CROP; + static override type = NodeTypesString.IMAGE_INSET_CROP; + static comfyClass = NodeTypesString.IMAGE_INSET_CROP; + + static override exposedActions = ["Reset Crop"]; + static maxResolution = 8192; + + constructor(title = ImageInsetCrop.title) { + super(title); + } + + override onAdded(graph: LGraph): void { + const measurementWidget = this.widgets[0]!; + let callback = measurementWidget.callback; + measurementWidget.callback = (...args) => { + this.setWidgetStep(); + callback && callback.apply(measurementWidget, [...args]); + }; + this.setWidgetStep(); + } + + override configure(info: SerializedLGraphNode): void { + super.configure(info); + this.setWidgetStep(); + } + + private setWidgetStep() { + const measurementWidget = this.widgets[0]!; + for (let i = 1; i <= 4; i++) { + if (measurementWidget.value === "Pixels") { + this.widgets[i]!.options.step = 80; + this.widgets[i]!.options.max = ImageInsetCrop.maxResolution; + } else { + this.widgets[i]!.options.step = 10; + this.widgets[i]!.options.max = 99; + } + } + } + + override async handleAction(action: string): Promise { + if (action === "Reset Crop") { + for (const widget of this.widgets) { + if (["left", "right", "top", "bottom"].includes(widget.name!)) { + widget.value = 0; + } + } + } + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, ImageInsetCrop); + } +} + +app.registerExtension({ + name: "rgthree.ImageInsetCrop", + async beforeRegisterNodeDef( + nodeType: ComfyNodeConstructor, + nodeData: ComfyObjectInfo, + _app: ComfyApp, + ) { + if (nodeData.name === NodeTypesString.IMAGE_INSET_CROP) { + ImageInsetCrop.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/label.ts b/rgthree-comfy/src_web/comfyui/label.ts new file mode 100644 index 0000000000000000000000000000000000000000..6b10c385ceee94ed4cadad509a380c1450276c5a --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/label.ts @@ -0,0 +1,207 @@ +import { app } from "scripts/app.js"; +import { RgthreeBaseVirtualNodeConstructor } from "typings/rgthree.js"; +import { RgthreeBaseVirtualNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; +import type { + LGraphCanvas as TLGraphCanvas, + LGraphNode, + AdjustedMouseEvent, + Vector2, +} from "typings/litegraph.js"; +import { rgthree } from "./rgthree.js"; + +/** + * A label node that allows you 
to put floating text anywhere on the graph. The text is the `Title` + * and the font size, family, color, alignment as well as a background color, padding, and + * background border radius can all be adjusted in the properties. Multiline text can be added from + * the properties panel (because ComfyUI let's you shift + enter there, only). + */ +export class Label extends RgthreeBaseVirtualNode { + static override type = NodeTypesString.LABEL; + static override title = NodeTypesString.LABEL; + override comfyClass = NodeTypesString.LABEL; + + static readonly title_mode = LiteGraph.NO_TITLE; + static collapsable = false; + + static "@fontSize" = { type: "number" }; + static "@fontFamily" = { type: "string" }; + static "@fontColor" = { type: "string" }; + static "@textAlign" = { type: "combo", values: ["left", "center", "right"] }; + static "@backgroundColor" = { type: "string" }; + static "@padding" = { type: "number" }; + static "@borderRadius" = { type: "number" }; + + override resizable = false; + + constructor(title = Label.title) { + super(title); + this.properties["fontSize"] = 12; + this.properties["fontFamily"] = "Arial"; + this.properties["fontColor"] = "#ffffff"; + this.properties["textAlign"] = "left"; + this.properties["backgroundColor"] = "transparent"; + this.properties["padding"] = 0; + this.properties["borderRadius"] = 0; + this.color = "#fff0"; + this.bgcolor = "#fff0"; + + this.onConstructed(); + } + + draw(ctx: CanvasRenderingContext2D) { + this.flags = this.flags || {}; + this.flags.allow_interaction = !this.flags.pinned; + ctx.save(); + this.color = "#fff0"; + this.bgcolor = "#fff0"; + const fontColor = this.properties["fontColor"] || "#ffffff"; + const backgroundColor = this.properties["backgroundColor"] || ""; + ctx.font = `${Math.max(this.properties["fontSize"] || 0, 1)}px ${ + this.properties["fontFamily"] ?? "Arial" + }`; + const padding = Number(this.properties["padding"]) ?? 0; + + const lines = this.title.replace(/\n*$/, "").split("\n"); + const maxWidth = Math.max(...lines.map((s) => ctx.measureText(s).width)); + this.size[0] = maxWidth + padding * 2; + this.size[1] = this.properties["fontSize"] * lines.length + padding * 2; + if (backgroundColor) { + ctx.beginPath(); + const borderRadius = Number(this.properties["borderRadius"]) || 0; + ctx.roundRect(0, 0, this.size[0], this.size[1], [borderRadius]); + ctx.fillStyle = backgroundColor; + ctx.fill(); + } + ctx.textAlign = "left"; + let textX = padding; + if (this.properties["textAlign"] === "center") { + ctx.textAlign = "center"; + textX = this.size[0] / 2; + } else if (this.properties["textAlign"] === "right") { + ctx.textAlign = "right"; + textX = this.size[0] - padding; + } + ctx.textBaseline = "top"; + ctx.fillStyle = fontColor; + let currentY = padding; + for (let i = 0; i < lines.length; i++) { + ctx.fillText(lines[i] || " ", textX, currentY); + currentY += this.properties["fontSize"]; + } + ctx.restore(); + } + + override onDblClick(event: AdjustedMouseEvent, pos: Vector2, canvas: TLGraphCanvas) { + // Since everything we can do here is in the properties, let's pop open the properties panel. 
+ LGraphCanvas.active_canvas.showShowNodePanel(this); + } + + override onShowCustomPanelInfo(panel: HTMLElement) { + panel.querySelector('div.property[data-property="Mode"]')?.remove(); + panel.querySelector('div.property[data-property="Color"]')?.remove(); + } + + override inResizeCorner(x: number, y: number) { + // A little ridiculous there's both a resizable property and this method separately to draw the + // resize icon... + return this.resizable; + } + + override getHelp() { + return ` +

+      <p>
+        The rgthree-comfy ${this.type!.replace("(rgthree)", "")} node allows you to add a floating
+        label to your workflow.
+      </p>
+      <p>
+        The text shown is the "Title" of the node and you can adjust the font size, font family,
+        font color, text alignment as well as a background color, padding, and background border
+        radius from the node's properties. You can double-click the node to open the properties
+        panel.
+      </p>
+      <ul>
+        <li><p>
+          <b>Pro tip #1:</b> You can add multiline text from the properties panel
+          (because ComfyUI only lets you shift + enter there).
+        </p></li>
+        <li><p>
+          <b>Pro tip #2:</b> You can use ComfyUI's native "pin" option in the right-click menu to
+          make the label stick to the workflow and let clicks "go through" it. You can right-click
+          at any time to unpin.
+        </p></li>
+        <li><p>
+          <b>Pro tip #3:</b> Color values are hexadecimal strings, like "#FFFFFF" for white, or
+          "#660000" for dark red. You can supply a 7th & 8th value (or a 5th if using shorthand)
+          to create a translucent color. For instance, "#FFFFFF88" is semi-transparent white.
+        </p></li>
+      </ul>
`; + } +} + +/** + * We override the drawNode to see if we're drawing our label and, if so, hijack it so we can draw + * it like we want. We also do call out to oldDrawNode, which takes care of very minimal things, + * like a select box. + */ +const oldDrawNode = LGraphCanvas.prototype.drawNode; +LGraphCanvas.prototype.drawNode = function (node: LGraphNode, ctx: CanvasRenderingContext2D) { + if (node.constructor === Label) { + // These get set very aggressively; maybe an extension is doing it. We'll just clear them out + // each time. + (node as Label).bgcolor = "transparent"; + (node as Label).color = "transparent"; + const v = oldDrawNode.apply(this, arguments as any); + (node as Label).draw(ctx); + return v; + } + + const v = oldDrawNode.apply(this, arguments as any); + return v; +}; + +/** + * We override LGraph getNodeOnPos to see if we're being called while also processing a mouse down + * and, if so, filter out any label nodes on labels that are pinned. This makes the click go + * "through" the label. We still allow right clicking (so you can unpin) and double click for the + * properties panel, though that takes two double clicks (one to select, one to actually double + * click). + */ +const oldGetNodeOnPos = LGraph.prototype.getNodeOnPos; +LGraph.prototype.getNodeOnPos = function ( + x: number, + y: number, + nodes_list?: LGraphNode[], + margin?: number, +) { + if ( + // processMouseDown always passes in the nodes_list + nodes_list && + rgthree.processingMouseDown && + rgthree.lastAdjustedMouseEvent?.type.includes("down") && + rgthree.lastAdjustedMouseEvent?.which === 1 + ) { + // Using the same logic from LGraphCanvas processMouseDown, let's see if we consider this a + // double click. + let isDoubleClick = LiteGraph.getTime() - LGraphCanvas.active_canvas.last_mouseclick < 300; + if (!isDoubleClick) { + nodes_list = [...nodes_list].filter((n) => !(n instanceof Label) || !n.flags?.pinned); + } + } + return oldGetNodeOnPos.apply(this, [x, y, nodes_list, margin]) as T | null; +}; + +// Register the extension. +app.registerExtension({ + name: "rgthree.Label", + registerCustomNodes() { + Label.setUp(); + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/menu_auto_nest.ts b/rgthree-comfy/src_web/comfyui/menu_auto_nest.ts new file mode 100644 index 0000000000000000000000000000000000000000..2c06ec523c8d4ad09301d0a7c03ff04cba012b4a --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/menu_auto_nest.ts @@ -0,0 +1,143 @@ +import { app } from "scripts/app.js"; +import type { + ContextMenuItem, + LGraphNode, + ContextMenu, + IContextMenuOptions, +} from "typings/litegraph.js"; +import { rgthree } from "./rgthree.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; + +const SPECIAL_ENTRIES = [/^(CHOOSE|NONE|DISABLE|OPEN)(\s|$)/i, /^\p{Extended_Pictographic}/gu]; + +/** + * Handles a large, flat list of string values given ContextMenu and breaks it up into subfolder, if + * they exist. This is experimental and initially built to work for CheckpointLoaderSimple. + */ +app.registerExtension({ + name: "rgthree.ContextMenuAutoNest", + async setup() { + const logger = rgthree.newLogSession("[ContextMenuAutoNest]"); + + const existingContextMenu = LiteGraph.ContextMenu; + + // @ts-ignore: TypeScript doesn't like this override. 
+ LiteGraph.ContextMenu = function (values: ContextMenuItem[], options: IContextMenuOptions) { + const threshold = CONFIG_SERVICE.getConfigValue("features.menu_auto_nest.threshold", 20); + const enabled = CONFIG_SERVICE.getConfigValue("features.menu_auto_nest.subdirs", false); + + // If we're not enabled, or are incompatible, then just call out safely. + let incompatible: string | boolean = !enabled || !!options?.extra?.rgthree_doNotNest; + if (!incompatible) { + if (values.length <= threshold) { + incompatible = `Skipping context menu auto nesting b/c threshold is not met (${threshold})`; + } + // If there's a rgthree_originalCallback, then we're nested and don't need to check things + // we only expect on the first nesting. + if (!options.parentMenu?.options.rgthree_originalCallback) { + // On first context menu, we require a callback and a flat list of options as strings. + if (!options?.callback) { + incompatible = `Skipping context menu auto nesting b/c a callback was expected.`; + } else if (values.some((i) => typeof i !== "string")) { + incompatible = `Skipping context menu auto nesting b/c not all values were strings.`; + } + } + } + if (incompatible) { + if (enabled) { + const [n, v] = logger.infoParts( + "Skipping context menu auto nesting for incompatible menu.", + ); + console[n]?.(...v); + } + return existingContextMenu.apply(this as any, [...arguments] as any); + } + + const folders: { [key: string]: ContextMenuItem[] } = {}; + const specialOps: ContextMenuItem[] = []; + const folderless: ContextMenuItem[] = []; + for (const value of values) { + if (!value) { + folderless.push(value); + continue; + } + const newValue = typeof value === "string" ? { content: value } : Object.assign({}, value); + newValue.rgthree_originalValue = value.rgthree_originalValue || value; + const valueContent = newValue.content || ''; + const splitBy = valueContent.indexOf("/") > -1 ? "/" : "\\"; + const valueSplit = valueContent.split(splitBy); + if (valueSplit.length > 1) { + const key = valueSplit.shift()!; + newValue.content = valueSplit.join(splitBy); + folders[key] = folders[key] || []; + folders[key]!.push(newValue); + } else if (SPECIAL_ENTRIES.some((r) => r.test(valueContent))) { + specialOps.push(newValue); + } else { + folderless.push(newValue); + } + } + + const foldersCount = Object.values(folders).length; + if (foldersCount > 0) { + // Propogate the original callback down through the options. + options.rgthree_originalCallback = + options.rgthree_originalCallback || + options.parentMenu?.options.rgthree_originalCallback || + options.callback; + const oldCallback = options.rgthree_originalCallback; + options.callback = undefined; + const newCallback = ( + item: ContextMenuItem, + options: IContextMenuOptions, + event: MouseEvent, + parentMenu: ContextMenu | undefined, + node: LGraphNode, + ) => { + oldCallback?.(item?.rgthree_originalValue!, options, event, undefined, node); + }; + const [n, v] = logger.infoParts(`Nested folders found (${foldersCount}).`); + console[n]?.(...v); + const newValues: ContextMenuItem[] = []; + for (const [folderName, folderValues] of Object.entries(folders)) { + newValues.push({ + content: `📁 ${folderName}`, + has_submenu: true, + callback: () => { + /* no-op, use the item callback. 
*/ + }, + submenu: { + options: folderValues.map((value) => { + value!.callback = newCallback; + return value; + }), + }, + }); + } + values = ([] as ContextMenuItem[]).concat( + specialOps.map((f) => { + if (typeof f === "string") { + f = { content: f }; + } + f!.callback = newCallback; + return f; + }), + newValues, + folderless.map((f) => { + if (typeof f === "string") { + f = { content: f }; + } + f!.callback = newCallback; + return f; + }), + ); + } + if (options.scale == null) { + options.scale = Math.max(app.canvas.ds?.scale || 1, 1); + } + return existingContextMenu.call(this as any, values, options); + }; + + LiteGraph.ContextMenu.prototype = existingContextMenu.prototype; + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/menu_copy_image.ts b/rgthree-comfy/src_web/comfyui/menu_copy_image.ts new file mode 100644 index 0000000000000000000000000000000000000000..5448ee41d5195da3b4b8d732465bc0c466ff64a9 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/menu_copy_image.ts @@ -0,0 +1,73 @@ +import { app } from "scripts/app.js"; +import type { LGraphCanvas, ContextMenuItem } from "typings/litegraph.js"; +import type { ComfyNodeConstructor, ComfyObjectInfo } from "typings/comfy.js"; + +const clipboardSupportedPromise = new Promise(async (resolve) => { + try { + // MDN says to check this, but it doesn't work in Mozilla... however, in secure contexts + // (localhost included), it's given by default if the user has it flagged.. so we should be + // able to check in the latter ClipboardItem too. + const result = await navigator.permissions.query({ name: "clipboard-write" } as any); + resolve(result.state === "granted"); + return; + } catch (e) { + try { + if (!navigator.clipboard.write) { + throw new Error(); + } + new ClipboardItem({ "image/png": new Blob([], { type: "image/png" }) }); + resolve(true); + return; + } catch (e) { + resolve(false); + } + } +}); + +/** + * Adds a "Copy Image" to images in similar fashion to the "native" Open Image and Save Image + * options. + */ +app.registerExtension({ + name: "rgthree.CopyImageToClipboard", + async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + if (nodeData.name.toLowerCase().includes("image")) { + if (await clipboardSupportedPromise) { + const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; + nodeType.prototype.getExtraMenuOptions = function ( + canvas: LGraphCanvas, + options: ContextMenuItem[], + ) { + getExtraMenuOptions ? getExtraMenuOptions.apply(this, arguments) : undefined; + // If we already have a copy image somehow, then let's skip ours. + if (this.imgs?.length) { + let img = + this.imgs[this.imageIndex || 0] || this.imgs[this.overIndex || 0] || this.imgs[0]; + const foundIdx = options.findIndex((option) => option?.content?.includes("Copy Image")); + if (img && foundIdx === -1) { + const menuItem: ContextMenuItem = { + content: "Copy Image (rgthree)", + callback: () => { + const canvas = document.createElement("canvas"); + const ctx = canvas.getContext("2d")!; + canvas.width = img.naturalWidth; + canvas.height = img.naturalHeight; + ctx.drawImage(img, 0, 0, img.naturalWidth, img.naturalHeight); + canvas.toBlob((blob) => { + navigator.clipboard.write([new ClipboardItem({ "image/png": blob! 
})]); + }); + }, + }; + let idx = options.findIndex((option) => option?.content?.includes("Open Image")) + 1; + if (idx != null) { + options.splice(idx, 0, menuItem); + } else { + options.unshift(menuItem); + } + } + } + }; + } + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/menu_queue_node.ts b/rgthree-comfy/src_web/comfyui/menu_queue_node.ts new file mode 100644 index 0000000000000000000000000000000000000000..4aeaa61611384080fa12891de70c66d890486d48 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/menu_queue_node.ts @@ -0,0 +1,97 @@ +import { app } from "scripts/app.js"; +import type { + LGraphCanvas as TLGraphCanvas, + ContextMenuItem, + LGraphNode, +} from "typings/litegraph.js"; +import type { ComfyNodeConstructor, ComfyObjectInfo } from "typings/comfy.js"; +import { rgthree } from "./rgthree.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; + +function getOutputNodes(nodes: LGraphNode[]) { + return ( + nodes?.filter((n) => { + return ( + n.mode != LiteGraph.NEVER && + ((n.constructor as any).nodeData as ComfyObjectInfo)?.output_node + ); + }) || [] + ); +} + +function showQueueNodesMenuIfOutputNodesAreSelected(existingOptions: ContextMenuItem[]) { + if (CONFIG_SERVICE.getConfigValue("features.menu_queue_selected_nodes") === false) { + return; + } + const outputNodes = getOutputNodes(Object.values(app.canvas.selected_nodes)); + const menuItem = { + content: `Queue Selected Output Nodes (rgthree)  `, + className: "rgthree-contextmenu-item", + callback: () => { + rgthree.queueOutputNodes(outputNodes.map((n) => n.id)); + }, + disabled: !outputNodes.length, + }; + + let idx = existingOptions.findIndex((o) => o?.content === "Outputs") + 1; + idx = idx || existingOptions.findIndex((o) => o?.content === "Align") + 1; + idx = idx || 3; + existingOptions.splice(idx, 0, menuItem); +} + +function showQueueGroupNodesMenuIfGroupIsSelected(existingOptions: ContextMenuItem[]) { + if (CONFIG_SERVICE.getConfigValue("features.menu_queue_selected_nodes") === false) { + return; + } + const group = + rgthree.lastAdjustedMouseEvent && + app.graph.getGroupOnPos( + rgthree.lastAdjustedMouseEvent.canvasX, + rgthree.lastAdjustedMouseEvent.canvasY, + ); + + const outputNodes = group && getOutputNodes(group._nodes); + const menuItem = { + content: `Queue Group Output Nodes (rgthree)  `, + className: "rgthree-contextmenu-item", + callback: () => { + outputNodes && rgthree.queueOutputNodes(outputNodes.map((n) => n.id)); + }, + disabled: !outputNodes?.length, + }; + + let idx = existingOptions.findIndex((o) => o?.content?.startsWith("Queue Selected ")) + 1; + idx = idx || existingOptions.findIndex((o) => o?.content === "Outputs") + 1; + idx = idx || existingOptions.findIndex((o) => o?.content === "Align") + 1; + idx = idx || 3; + existingOptions.splice(idx, 0, menuItem); +} + +/** + * Adds a "Queue Node" menu item to all output nodes, working with `rgthree.queueOutputNode` to + * execute only a single node's path. + */ +app.registerExtension({ + name: "rgthree.QueueNode", + async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; + nodeType.prototype.getExtraMenuOptions = function ( + canvas: TLGraphCanvas, + options: ContextMenuItem[], + ) { + getExtraMenuOptions ? 
getExtraMenuOptions.apply(this, arguments) : undefined; + showQueueNodesMenuIfOutputNodesAreSelected(options); + showQueueGroupNodesMenuIfGroupIsSelected(options); + }; + }, + + async setup() { + const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; + LGraphCanvas.prototype.getCanvasMenuOptions = function (...args: any[]) { + const options = getCanvasMenuOptions.apply(this, [...args] as any); + showQueueNodesMenuIfOutputNodesAreSelected(options); + showQueueGroupNodesMenuIfGroupIsSelected(options); + return options; + }; + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/muter.ts b/rgthree-comfy/src_web/comfyui/muter.ts new file mode 100644 index 0000000000000000000000000000000000000000..1741f3b2ba70c9ee3bf4391b3608d7ef8afbacbf --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/muter.ts @@ -0,0 +1,50 @@ +import { app } from "scripts/app.js"; +import { BaseNodeModeChanger } from "./base_node_mode_changer.js"; +import { NodeTypesString } from "./constants.js"; +import type { LGraphNode } from "typings/litegraph.js"; + +const MODE_MUTE = 2; +const MODE_ALWAYS = 0; + +class MuterNode extends BaseNodeModeChanger { + static override exposedActions = ["Mute all", "Enable all", "Toggle all"]; + + static override type = NodeTypesString.FAST_MUTER; + static override title = NodeTypesString.FAST_MUTER; + override comfyClass = NodeTypesString.FAST_MUTER; + override readonly modeOn = MODE_ALWAYS; + override readonly modeOff = MODE_MUTE; + + constructor(title = MuterNode.title) { + super(title); + this.onConstructed(); + } + + override async handleAction(action: string) { + if (action === "Mute all") { + for (const widget of this.widgets) { + this.forceWidgetOff(widget, true); + } + } else if (action === "Enable all") { + for (const widget of this.widgets) { + this.forceWidgetOn(widget, true); + } + } else if (action === "Toggle all") { + for (const widget of this.widgets) { + this.forceWidgetToggle(widget, true); + } + } + } +} + +app.registerExtension({ + name: "rgthree.Muter", + registerCustomNodes() { + MuterNode.setUp(); + }, + loadedGraphNode(node: LGraphNode) { + if (node.type == MuterNode.title) { + (node as any)._tempWidth = node.size[0]; + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/node_collector.ts b/rgthree-comfy/src_web/comfyui/node_collector.ts new file mode 100644 index 0000000000000000000000000000000000000000..796168b074d72411edb6de7e5b16733aa3d6d5a2 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/node_collector.ts @@ -0,0 +1,170 @@ +import { app } from "scripts/app.js"; +import type { + LLink, + LGraph, + ContextMenuItem, + LGraphCanvas, + SerializedLGraphNode, + LGraphNode as TLGraphNode, + IContextMenuOptions, + ContextMenu, +} from "typings/litegraph.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { wait } from "rgthree/common/shared_utils.js"; +import { ComfyWidgets } from "scripts/widgets.js"; +import { BaseCollectorNode } from "./base_node_collector.js"; +import { NodeTypesString } from "./constants.js"; + +/** + * The Collector Node. Takes any number of inputs as connections for nodes and collects them into + * one outputs. The next node will decide what to do with them. + * + * Currently only works with the Fast Muter, Fast Bypasser, and Fast Actions Button. 
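+ *
+ * A rough usage sketch (the `someNode` and `fastMuter` variables are hypothetical, and the
+ * `connect(slot, targetNode, targetSlot)` call follows the standard LiteGraph signature used
+ * elsewhere in this file):
+ *
+ *   const collector = new CollectorNode();
+ *   app.graph.add(collector);
+ *   someNode.connect(0, collector, 0);    // any node can feed the collector
+ *   collector.connect(0, fastMuter, 0);   // its single "Output" then feeds e.g. a Fast Muter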
+ */ +class CollectorNode extends BaseCollectorNode { + static override type = NodeTypesString.NODE_COLLECTOR; + static override title = NodeTypesString.NODE_COLLECTOR; + override comfyClass = NodeTypesString.NODE_COLLECTOR; + + constructor(title = CollectorNode.title) { + super(title); + this.onConstructed(); + } + + override onConstructed(): boolean { + this.addOutput("Output", "*"); + return super.onConstructed(); + } + + override configure(info: SerializedLGraphNode): void { + // Patch a small issue (~14h) where multiple OPT_CONNECTIONS may have been created. + // https://github.com/rgthree/rgthree-comfy/issues/206 + // TODO: This can probably be removed within a few weeks. + if (info.outputs?.length) { + info.outputs.length = 1; + } + super.configure(info); + } +} + +/** Legacy "Combiner" */ +class CombinerNode extends CollectorNode { + static legacyType = "Node Combiner (rgthree)"; + static override title = "‼️ Node Combiner [DEPRECATED]"; + + constructor(title = CombinerNode.title) { + super(title); + + const note = ComfyWidgets["STRING"]( + this, + "last_seed", + ["STRING", { multiline: true }], + app, + ).widget; + note.inputEl!.value = + 'The Node Combiner has been renamed to Node Collector. You can right-click and select "Update to Node Collector" to attempt to automatically update.'; + note.inputEl!.readOnly = true; + note.inputEl!.style.backgroundColor = "#332222"; + note.inputEl!.style.fontWeight = "bold"; + note.inputEl!.style.fontStyle = "italic"; + note.inputEl!.style.opacity = "0.8"; + + this.getExtraMenuOptions = (_: LGraphCanvas, options: ContextMenuItem[]) => { + options.splice(options.length - 1, 0, { + content: "‼️ Update to Node Collector", + callback: ( + _value: ContextMenuItem, + _options: IContextMenuOptions, + _event: MouseEvent, + _parentMenu: ContextMenu | undefined, + _node: TLGraphNode, + ) => { + updateCombinerToCollector(this); + }, + }); + }; + } + + override configure(info: SerializedLGraphNode) { + super.configure(info); + if (this.title != CombinerNode.title && !this.title.startsWith("‼️")) { + this.title = "‼️ " + this.title; + } + } +} + +/** + * Updates a Node Combiner to a Node Collector. + */ +async function updateCombinerToCollector(node: TLGraphNode) { + if (node.type === CombinerNode.legacyType) { + // Create a new CollectorNode. + const newNode = new CollectorNode(); + if (node.title != CombinerNode.title) { + newNode.title = node.title.replace("‼️ ", ""); + } + // Port the position, size, and properties from the old node. + newNode.pos = [...node.pos]; + newNode.size = [...node.size]; + newNode.properties = { ...node.properties }; + // We now collect the links data, inputs and outputs, of the old node since these will be + // lost when we remove it. + const links: any[] = []; + for (const [index, output] of node.outputs.entries()) { + for (const linkId of output.links || []) { + const link: LLink = (app.graph as LGraph).links[linkId]!; + if (!link) continue; + const targetNode = app.graph.getNodeById(link.target_id); + links.push({ node: newNode, slot: index, targetNode, targetSlot: link.target_slot }); + } + } + for (const [index, input] of node.inputs.entries()) { + const linkId = input.link; + if (linkId) { + const link: LLink = (app.graph as LGraph).links[linkId]!; + const originNode = app.graph.getNodeById(link.origin_id); + links.push({ + node: originNode, + slot: link.origin_slot, + targetNode: newNode, + targetSlot: index, + }); + } + } + // Add the new node, remove the old node. 
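+ // The steps below run in order with a short wait() between each: add the replacement node, re-create each saved link against it, then remove the old node.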
+ app.graph.add(newNode); + await wait(); + // Now go through and connect the other nodes up as they were. + for (const link of links) { + link.node.connect(link.slot, link.targetNode, link.targetSlot); + } + await wait(); + app.graph.remove(node); + } +} + +app.registerExtension({ + name: "rgthree.NodeCollector", + registerCustomNodes() { + addConnectionLayoutSupport(CollectorNode, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + + LiteGraph.registerNodeType(CollectorNode.title, CollectorNode); + CollectorNode.category = CollectorNode._category; + }, +}); + +app.registerExtension({ + name: "rgthree.NodeCombiner", + registerCustomNodes() { + addConnectionLayoutSupport(CombinerNode, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + + LiteGraph.registerNodeType(CombinerNode.legacyType, CombinerNode); + CombinerNode.category = CombinerNode._category; + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/node_mode_relay.ts b/rgthree-comfy/src_web/comfyui/node_mode_relay.ts new file mode 100644 index 0000000000000000000000000000000000000000..1840bd09aaf559209d5e0f059767c24145ec445b --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/node_mode_relay.ts @@ -0,0 +1,287 @@ +import { app } from "scripts/app.js"; +import type { + INodeInputSlot, + INodeOutputSlot, + LGraphCanvas, + LGraphNode, + LLink, + SerializedLGraphNode, + Vector2, +} from "typings/litegraph.js"; +import type { NodeMode } from "typings/comfy.js"; +import { + PassThroughFollowing, + addConnectionLayoutSupport, + getConnectedInputNodesAndFilterPassThroughs, + getConnectedOutputNodesAndFilterPassThroughs, +} from "./utils.js"; +import { wait } from "rgthree/common/shared_utils.js"; +import { BaseCollectorNode } from "./base_node_collector.js"; +import { NodeTypesString, stripRgthree } from "./constants.js"; +import { fitString } from "./utils_canvas.js"; +import { rgthree } from "./rgthree.js"; + +const MODE_ALWAYS = 0; +const MODE_MUTE = 2; +const MODE_BYPASS = 4; +const MODE_REPEATS = [MODE_MUTE, MODE_BYPASS]; +const MODE_NOTHING = -99; // MADE THIS UP. + +const MODE_TO_OPTION = new Map([ + [MODE_ALWAYS, "ACTIVE"], + [MODE_MUTE, "MUTE"], + [MODE_BYPASS, "BYPASS"], + [MODE_NOTHING, "NOTHING"], +]); + +const OPTION_TO_MODE = new Map([ + ["ACTIVE", MODE_ALWAYS], + ["MUTE", MODE_MUTE], + ["BYPASS", MODE_BYPASS], + ["NOTHING", MODE_NOTHING], +]); + +const MODE_TO_PROPERTY = new Map([ + [MODE_MUTE, "on_muted_inputs"], + [MODE_BYPASS, "on_bypassed_inputs"], + [MODE_ALWAYS, "on_any_active_inputs"], +]); + +const logger = rgthree.newLogSession("[NodeModeRelay]"); + +/** + * Like a BaseCollectorNode, this relay node connects to a Repeater node and _relays_ mode changes + * changes to the repeater (so it can go on to modify its connections). 
+ */ +class NodeModeRelay extends BaseCollectorNode { + override readonly inputsPassThroughFollowing: PassThroughFollowing = PassThroughFollowing.ALL; + + static override type = NodeTypesString.NODE_MODE_RELAY; + static override title = NodeTypesString.NODE_MODE_RELAY; + override comfyClass = NodeTypesString.NODE_MODE_RELAY; + + static "@on_muted_inputs" = { + type: "combo", + values: ["MUTE", "ACTIVE", "BYPASS", "NOTHING"], + }; + + static "@on_bypassed_inputs" = { + type: "combo", + values: ["BYPASS", "ACTIVE", "MUTE", "NOTHING"], + }; + + static "@on_any_active_inputs" = { + type: "combo", + values: ["BYPASS", "ACTIVE", "MUTE", "NOTHING"], + }; + + constructor(title?: string) { + super(title); + this.properties["on_muted_inputs"] = "MUTE"; + this.properties["on_bypassed_inputs"] = "BYPASS"; + this.properties["on_any_active_inputs"] = "ACTIVE"; + + this.onConstructed(); + } + + override onConstructed() { + this.addOutput("REPEATER", "_NODE_REPEATER_", { + color_on: "#Fc0", + color_off: "#a80", + shape: LiteGraph.ARROW_SHAPE, + }); + + setTimeout(() => { + this.stabilize(); + }, 500); + return super.onConstructed(); + } + + override onModeChange(from: NodeMode, to: NodeMode) { + super.onModeChange(from, to); + // If we aren't connected to anything, then we'll use our mode to relay when it changes. + if (this.inputs.length <= 1 && !this.isInputConnected(0) && this.isAnyOutputConnected()) { + const [n, v] = logger.infoParts(`Mode change without any inputs; relaying our mode.`); + console[n]?.(...v); + this.dispatchModeToRepeater(this.mode); + } + } + + override configure(info: SerializedLGraphNode): void { + // Patch a small issue (~14h) where multiple OPT_CONNECTIONS may have been created. + // https://github.com/rgthree/rgthree-comfy/issues/206 + // TODO: This can probably be removed within a few weeks. + if (info.outputs?.length) { + info.outputs.length = 1; + } + super.configure(info); + } + + override onDrawForeground(ctx: CanvasRenderingContext2D, canvas: LGraphCanvas): void { + if (this.flags?.collapsed) { + return; + } + if ( + this.properties["on_muted_inputs"] !== "MUTE" || + this.properties["on_bypassed_inputs"] !== "BYPASS" || + this.properties["on_any_active_inputs"] != "ACTIVE" + ) { + let margin = 15; + ctx.textAlign = "left"; + let label = `*(MUTE > ${this.properties["on_muted_inputs"]}, `; + label += `BYPASS > ${this.properties["on_bypassed_inputs"]}, `; + label += `ACTIVE > ${this.properties["on_any_active_inputs"]})`; + ctx.fillStyle = LiteGraph.WIDGET_SECONDARY_TEXT_COLOR; + const oldFont = ctx.font; + ctx.font = "italic " + (LiteGraph.NODE_SUBTEXT_SIZE - 2) + "px Arial"; + ctx.fillText(fitString(ctx, label, this.size[0] - 20), 15, this.size[1] - 6); + ctx.font = oldFont; + } + } + + override computeSize(out: Vector2) { + let size = super.computeSize(out); + if ( + this.properties["on_muted_inputs"] !== "MUTE" || + this.properties["on_bypassed_inputs"] !== "BYPASS" || + this.properties["on_any_active_inputs"] != "ACTIVE" + ) { + size[1] += 17; + } + return size; + } + override onConnectOutput( + outputIndex: number, + inputType: string | -1, + inputSlot: INodeInputSlot, + inputNode: LGraphNode, + inputIndex: number, + ): boolean { + let canConnect = super.onConnectOutput?.( + outputIndex, + inputType, + inputSlot, + inputNode, + inputIndex, + ); + let nextNode = getConnectedOutputNodesAndFilterPassThroughs(this, inputNode)[0] ?? 
inputNode; + return canConnect && nextNode.type === NodeTypesString.NODE_MODE_REPEATER; + } + + override onConnectionsChange( + type: number, + slotIndex: number, + isConnected: boolean, + link_info: LLink, + ioSlot: INodeOutputSlot | INodeInputSlot, + ): void { + super.onConnectionsChange(type, slotIndex, isConnected, link_info, ioSlot); + setTimeout(() => { + this.stabilize(); + }, 500); + } + + stabilize() { + // If we aren't connected to a repeater, then theres no sense in checking. And if we are, but + // have no inputs, then we're also not ready. + if (!this.graph || !this.isAnyOutputConnected() || !this.isInputConnected(0)) { + return; + } + const inputNodes = getConnectedInputNodesAndFilterPassThroughs( + this, + this, + -1, + this.inputsPassThroughFollowing, + ); + let mode: NodeMode | -99 | null = undefined; + for (const inputNode of inputNodes) { + // If we haven't set our mode to be, then let's set it. Otherwise, mode will stick if it + // remains constant, otherwise, if we hit an ALWAYS, then we'll unmute all repeaters and + // if not then we won't do anything. + if (mode === undefined) { + mode = inputNode.mode; + } else if (mode === inputNode.mode && MODE_REPEATS.includes(mode)) { + continue; + } else if (inputNode.mode === MODE_ALWAYS || mode === MODE_ALWAYS) { + mode = MODE_ALWAYS; + } else { + mode = null; + } + } + + this.dispatchModeToRepeater(mode); + setTimeout(() => { + this.stabilize(); + }, 500); + } + + /** + * Sends the mode to the repeater, checking to see if we're modifying our mode. + */ + private dispatchModeToRepeater(mode?: NodeMode | -99 | null) { + if (mode != null) { + const propertyVal = this.properties?.[MODE_TO_PROPERTY.get(mode) || ""]; + const newMode = OPTION_TO_MODE.get(propertyVal); + mode = (newMode !== null ? newMode : mode) as NodeMode | -99; + if (mode !== null && mode !== MODE_NOTHING) { + if (this.outputs?.length) { + const outputNodes = getConnectedOutputNodesAndFilterPassThroughs(this); + for (const outputNode of outputNodes) { + outputNode.mode = mode; + wait(16).then(() => { + outputNode.setDirtyCanvas(true, true); + }); + } + } + } + } + } + + override getHelp() { + return ` +

+ This node will relay its input nodes' modes (Mute, Bypass, or Active) to a connected
+ ${stripRgthree(NodeTypesString.NODE_MODE_REPEATER)} (which would then repeat that mode
+ change to all of its inputs).
+
+   • When all connected input nodes are muted, the relay will set a connected repeater to
+     mute (by default).
+   • When all connected input nodes are bypassed, the relay will set a connected repeater to
+     bypass (by default).
+   • When any connected input node is active, the relay will set a connected repeater to
+     active (by default).
+   • If no inputs are connected, the relay will set a connected repeater to its own mode
+     whenever its mode changes. Note that if any inputs are connected, the rules above apply
+     and the relay's own mode does not matter.
+
+ Note: you can change which signal is sent for each of the cases above in the node's
+ Properties. For instance, you could configure an inverse relay that sends MUTE when any of
+ its inputs are active (instead of sending ACTIVE), and sends ACTIVE when all of its inputs
+ are muted (instead of sending MUTE), etc.
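For illustration only (this sketch is not part of the patch), the resolution described above boils down to the following, with the node's three properties passed in as a plain object and the mode constants mirroring the ones defined at the top of this file:

const MODE_ALWAYS = 0, MODE_MUTE = 2, MODE_BYPASS = 4;
type RelaySignal = "ACTIVE" | "MUTE" | "BYPASS" | "NOTHING";
type RelayProps = {
  on_muted_inputs: RelaySignal;
  on_bypassed_inputs: RelaySignal;
  on_any_active_inputs: RelaySignal;
};

// Returns the signal the relay would send for a given set of input modes, or null for a
// mixed mute/bypass state where nothing is dispatched.
function resolveRelaySignal(inputModes: number[], props: RelayProps): RelaySignal | null {
  if (!inputModes.length) return null;
  if (inputModes.some((m) => m === MODE_ALWAYS)) return props.on_any_active_inputs;
  if (inputModes.every((m) => m === MODE_MUTE)) return props.on_muted_inputs;
  if (inputModes.every((m) => m === MODE_BYPASS)) return props.on_bypassed_inputs;
  return null;
}

// Defaults: all-muted inputs relay MUTE.
resolveRelaySignal([MODE_MUTE, MODE_MUTE], {
  on_muted_inputs: "MUTE", on_bypassed_inputs: "BYPASS", on_any_active_inputs: "ACTIVE",
}); // => "MUTE"

// The "inverse relay" mentioned above: any active input sends MUTE, all-muted sends ACTIVE.
resolveRelaySignal([MODE_ALWAYS, MODE_BYPASS], {
  on_muted_inputs: "ACTIVE", on_bypassed_inputs: "BYPASS", on_any_active_inputs: "MUTE",
}); // => "MUTE"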

+ `; + } +} + +app.registerExtension({ + name: "rgthree.NodeModeRepeaterHelper", + registerCustomNodes() { + addConnectionLayoutSupport(NodeModeRelay, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + + LiteGraph.registerNodeType(NodeModeRelay.type, NodeModeRelay); + NodeModeRelay.category = NodeModeRelay._category; + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/node_mode_repeater.ts b/rgthree-comfy/src_web/comfyui/node_mode_repeater.ts new file mode 100644 index 0000000000000000000000000000000000000000..fa127f5a98b0512f33efc1325e54dba47f86ae59 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/node_mode_repeater.ts @@ -0,0 +1,220 @@ +import { app } from "scripts/app.js"; +import { BaseCollectorNode } from "./base_node_collector.js"; +import { NodeTypesString, stripRgthree } from "./constants.js"; + +import type { + INodeInputSlot, + INodeOutputSlot, + LGraphGroup, + LGraphNode, + LLink, + SerializedLGraphNode, +} from "typings/litegraph.js"; +import { + PassThroughFollowing, + addConnectionLayoutSupport, + getConnectedInputNodesAndFilterPassThroughs, + getConnectedOutputNodesAndFilterPassThroughs, +} from "./utils.js"; +import { NodeMode } from "typings/comfy.js"; + +class NodeModeRepeater extends BaseCollectorNode { + override readonly inputsPassThroughFollowing: PassThroughFollowing = PassThroughFollowing.ALL; + + static override type = NodeTypesString.NODE_MODE_REPEATER; + static override title = NodeTypesString.NODE_MODE_REPEATER; + override comfyClass = NodeTypesString.NODE_MODE_REPEATER; + + private hasRelayInput = false; + private hasTogglerOutput = false; + + constructor(title?: string) { + super(title); + this.onConstructed(); + } + + override onConstructed(): boolean { + this.addOutput("OPT_CONNECTION", "*", { + color_on: "#Fc0", + color_off: "#a80", + }); + + return super.onConstructed(); + } + + override configure(info: SerializedLGraphNode): void { + // Patch a small issue (~14h) where multiple OPT_CONNECTIONS may have been created. + // https://github.com/rgthree/rgthree-comfy/issues/206 + // TODO: This can probably be removed within a few weeks. + if (info.outputs?.length) { + info.outputs.length = 1; + } + super.configure(info); + } + + override onConnectOutput( + outputIndex: number, + inputType: string | -1, + inputSlot: INodeInputSlot, + inputNode: LGraphNode, + inputIndex: number, + ): boolean { + // We can only connect to a a FAST_MUTER or FAST_BYPASSER if we aren't connectged to a relay, since the relay wins. + let canConnect = !this.hasRelayInput; + canConnect = + canConnect && super.onConnectOutput(outputIndex, inputType, inputSlot, inputNode, inputIndex); + // Output can only connect to a FAST MUTER, FAST BYPASSER, NODE_COLLECTOR OR ACTION BUTTON + let nextNode = getConnectedOutputNodesAndFilterPassThroughs(this, inputNode)[0] || inputNode; + return ( + canConnect && + [ + NodeTypesString.FAST_MUTER, + NodeTypesString.FAST_BYPASSER, + NodeTypesString.NODE_COLLECTOR, + NodeTypesString.FAST_ACTIONS_BUTTON, + NodeTypesString.REROUTE, + NodeTypesString.RANDOM_UNMUTER, + ].includes(nextNode.type || "") + ); + } + + override onConnectInput( + inputIndex: number, + outputType: string | -1, + outputSlot: INodeOutputSlot, + outputNode: LGraphNode, + outputIndex: number, + ): boolean { + // We can only connect to a a FAST_MUTER or FAST_BYPASSER if we aren't connectged to a relay, since the relay wins. 
+ let canConnect = super.onConnectInput?.( + inputIndex, + outputType, + outputSlot, + outputNode, + outputIndex, + ); + // Output can only connect to a FAST MUTER or FAST BYPASSER + let nextNode = getConnectedOutputNodesAndFilterPassThroughs(this, outputNode)[0] || outputNode; + const isNextNodeRelay = nextNode.type === NodeTypesString.NODE_MODE_RELAY; + return canConnect && (!isNextNodeRelay || !this.hasTogglerOutput); + } + + override onConnectionsChange( + type: number, + slotIndex: number, + isConnected: boolean, + linkInfo: LLink, + ioSlot: INodeOutputSlot | INodeInputSlot, + ): void { + super.onConnectionsChange(type, slotIndex, isConnected, linkInfo, ioSlot); + + let hasTogglerOutput = false; + let hasRelayInput = false; + + const outputNodes = getConnectedOutputNodesAndFilterPassThroughs(this); + for (const outputNode of outputNodes) { + if ( + outputNode?.type === NodeTypesString.FAST_MUTER || + outputNode?.type === NodeTypesString.FAST_BYPASSER + ) { + hasTogglerOutput = true; + break; + } + } + + const inputNodes = getConnectedInputNodesAndFilterPassThroughs(this); + for (const [index, inputNode] of inputNodes.entries()) { + if (inputNode?.type === NodeTypesString.NODE_MODE_RELAY) { + // We can't be connected to a relay if we're connected to a toggler. Something has gone wrong. + if (hasTogglerOutput) { + console.log(`Can't be connected to a Relay if also output to a toggler.`); + this.disconnectInput(index); + } else { + hasRelayInput = true; + if (this.inputs[index]) { + this.inputs[index]!.color_on = "#FC0"; + this.inputs[index]!.color_off = "#a80"; + } + } + } else { + inputNode.mode = this.mode; + } + } + + this.hasTogglerOutput = hasTogglerOutput; + this.hasRelayInput = hasRelayInput; + + // If we have a relay input, then we should remove the toggler output, or add it if not. + if (this.hasRelayInput) { + if (this.outputs[0]) { + this.disconnectOutput(0); + this.removeOutput(0); + } + } else if (!this.outputs[0]) { + this.addOutput("OPT_CONNECTION", "*", { + color_on: "#Fc0", + color_off: "#a80", + }); + } + } + + /** When a mode change, we want all connected nodes to match except for connected relays. */ + override onModeChange(from: NodeMode, to: NodeMode) { + super.onModeChange(from, to); + const linkedNodes = getConnectedInputNodesAndFilterPassThroughs(this).filter( + (node) => node.type !== NodeTypesString.NODE_MODE_RELAY, + ); + if (linkedNodes.length) { + for (const node of linkedNodes) { + if (node.type !== NodeTypesString.NODE_MODE_RELAY) { + node.mode = this.mode; + } + } + } else if (app.graph._groups?.length) { + // No linked nodes.. check if we're in a group. + for (const group of app.graph._groups as LGraphGroup[]) { + group.recomputeInsideNodes(); + if (group._nodes?.includes(this)) { + for (const node of group._nodes) { + node.mode = this.mode; + } + } + } + } + } + + override getHelp(): string { + return ` +

+ When this node's mode (Mute, Bypass, Active) changes, it will "repeat" that mode to all
+ connected input nodes or, if there are no connected nodes AND it is overlapping a group,
+ "repeat" its mode to all nodes in that group.
+
+   • Optionally, connect this node's output to a ${stripRgthree(NodeTypesString.FAST_MUTER)}
+     or ${stripRgthree(NodeTypesString.FAST_BYPASSER)} for a single toggle to quickly
+     mute/bypass all of its connected nodes.
+   • Optionally, connect a ${stripRgthree(NodeTypesString.NODE_MODE_RELAY)} to this node's
+     inputs to have it automatically toggle its mode. If connected, the relay always takes
+     precedence (and will disconnect any connected fast togglers).
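As a rough sketch (not part of the patch, with simplified node and group shapes standing in for LiteGraph's types), the repeat behavior described above amounts to:

const RELAY_TYPE = "relay"; // stand-in for NodeTypesString.NODE_MODE_RELAY
interface SketchNode { mode: number; type?: string; }
interface SketchGroup { nodes: SketchNode[]; }

function repeatMode(repeater: SketchNode, inputs: SketchNode[], groups: SketchGroup[]) {
  // Relays are skipped: they drive the repeater rather than being driven by it.
  const linked = inputs.filter((n) => n.type !== RELAY_TYPE);
  if (linked.length) {
    for (const node of linked) node.mode = repeater.mode;
    return;
  }
  // No connected inputs: repeat the mode to every node in any group the repeater overlaps.
  for (const group of groups) {
    if (group.nodes.includes(repeater)) {
      for (const node of group.nodes) node.mode = repeater.mode;
    }
  }
}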
+ `; + } +} + +app.registerExtension({ + name: "rgthree.NodeModeRepeater", + registerCustomNodes() { + addConnectionLayoutSupport(NodeModeRepeater, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + + LiteGraph.registerNodeType(NodeModeRepeater.type, NodeModeRepeater); + NodeModeRepeater.category = NodeModeRepeater._category; + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/power_lora_loader.ts b/rgthree-comfy/src_web/comfyui/power_lora_loader.ts new file mode 100644 index 0000000000000000000000000000000000000000..9b76142a4b3eac090e69698a9cdc91fe75da0ae1 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/power_lora_loader.ts @@ -0,0 +1,804 @@ +import { app } from "scripts/app.js"; +import type { + ContextMenuItem, + LGraphNode as TLGraphNode, + IWidget, + LGraphCanvas, + SerializedLGraphNode, + Vector2, + AdjustedMouseEvent, +} from "typings/litegraph.js"; +import type { ComfyObjectInfo, ComfyNodeConstructor } from "typings/comfy.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { rgthree } from "./rgthree.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { NodeTypesString } from "./constants.js"; +import { + drawInfoIcon, + drawNumberWidgetPart, + drawRoundedRectangle, + drawTogglePart, + fitString, + isLowQuality, +} from "./utils_canvas.js"; +import { + RgthreeBaseHitAreas, + RgthreeBaseWidget, + RgthreeBetterButtonWidget, + RgthreeDividerWidget, +} from "./utils_widgets.js"; +import { rgthreeApi } from "rgthree/common/rgthree_api.js"; +import { showLoraChooser } from "./utils_menu.js"; +import { moveArrayItem, removeArrayItem } from "rgthree/common/shared_utils.js"; +import { RgthreeInfoDialog } from "./dialog_info.js"; +import type { RgthreeModelInfo } from "typings/rgthree.js"; +import { SERVICE as MODEL_INFO_SERVICE } from "rgthree/common/model_info_service.js"; +// import { RgthreePowerLoraChooserDialog } from "./dialog_power_lora_chooser.js"; + +const PROP_LABEL_SHOW_STRENGTHS = "Show Strengths"; +const PROP_LABEL_SHOW_STRENGTHS_STATIC = `@${PROP_LABEL_SHOW_STRENGTHS}`; +const PROP_VALUE_SHOW_STRENGTHS_SINGLE = "Single Strength"; +const PROP_VALUE_SHOW_STRENGTHS_SEPARATE = "Separate Model & Clip"; + +/** + * The Power Lora Loader is a super-simply Lora Loader node that can load multiple Loras at once + * in an ultra-condensed node allowing fast toggling, and advanced strength setting. + */ +class RgthreePowerLoraLoader extends RgthreeBaseServerNode { + static override title = NodeTypesString.POWER_LORA_LOADER; + static override type = NodeTypesString.POWER_LORA_LOADER; + static comfyClass = NodeTypesString.POWER_LORA_LOADER; + + override serialize_widgets = true; + + private logger = rgthree.newLogSession(`[Power Lora Stack]`); + + static [PROP_LABEL_SHOW_STRENGTHS_STATIC] = { + type: "combo", + values: [PROP_VALUE_SHOW_STRENGTHS_SINGLE, PROP_VALUE_SHOW_STRENGTHS_SEPARATE], + }; + + /** Counts the number of lora widgets. This is used to give unique names. */ + private loraWidgetsCounter = 0; + + /** Keep track of the spacer, new lora widgets will go before it when it exists. */ + private widgetButtonSpacer: IWidget | null = null; + + constructor(title = NODE_CLASS.title) { + super(title); + + this.properties[PROP_LABEL_SHOW_STRENGTHS] = PROP_VALUE_SHOW_STRENGTHS_SINGLE; + + // Prefetch loras list. 
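+ // Fire-and-forget here so the list is already loading (or cached) by the time the user clicks "➕ Add Lora".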
+ rgthreeApi.getLoras(); + } + + /** + * Handles configuration from a saved workflow by first removing our default widgets that were + * added in `onNodeCreated`, letting `super.configure` and do nothing, then create our lora + * widgets and, finally, add back in our default widgets. + */ + override configure(info: SerializedLGraphNode): void { + while (this.widgets?.length) this.removeWidget(0); + this.widgetButtonSpacer = null; + super.configure(info); + + (this as any)._tempWidth = this.size[0]; + (this as any)._tempHeight = this.size[1]; + for (const widgetValue of info.widgets_values || []) { + if (widgetValue?.lora !== undefined) { + const widget = this.addNewLoraWidget(); + widget.value = { ...widgetValue }; + } + } + this.addNonLoraWidgets(); + this.size[0] = (this as any)._tempWidth; + this.size[1] = Math.max((this as any)._tempHeight, this.computeSize()[1]); + } + + /** + * Adds the non-lora widgets. If we'll be configured then we remove them and add them back, so + * this is really only for newly created nodes in the current session. + */ + override onNodeCreated() { + super.onNodeCreated?.(); + this.addNonLoraWidgets(); + const computed = this.computeSize(); + this.size = this.size || [0, 0]; + this.size[0] = Math.max(this.size[0], computed[0]); + this.size[1] = Math.max(this.size[1], computed[1]); + this.setDirtyCanvas(true, true); + } + + /** Adds a new lora widget in the proper slot. */ + private addNewLoraWidget(lora?: string) { + this.loraWidgetsCounter++; + const widget = this.addCustomWidget( + new PowerLoraLoaderWidget("lora_" + this.loraWidgetsCounter), + ); + if (lora) widget.setLora(lora); + if (this.widgetButtonSpacer) { + moveArrayItem(this.widgets, widget, this.widgets.indexOf(this.widgetButtonSpacer)); + } + return widget; + } + + /** Adds the non-lora widgets around any lora ones that may be there from configuration. */ + private addNonLoraWidgets() { + moveArrayItem( + this.widgets, + this.addCustomWidget( + new RgthreeDividerWidget({ marginTop: 4, marginBottom: 0, thickness: 0 }), + ), + 0, + ); + moveArrayItem(this.widgets, this.addCustomWidget(new PowerLoraLoaderHeaderWidget()), 1); + + this.widgetButtonSpacer = this.addCustomWidget( + new RgthreeDividerWidget({ marginTop: 4, marginBottom: 0, thickness: 0 }), + ); + + this.addCustomWidget( + new RgthreeBetterButtonWidget( + "➕ Add Lora", + (event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) => { + rgthreeApi.getLoras().then((loras) => { + showLoraChooser( + event as PointerEvent, + (value: ContextMenuItem | string) => { + if (typeof value === "string") { + if (value.includes("Power Lora Chooser")) { + // new RgthreePowerLoraChooserDialog().show(); + } else if (value !== "NONE") { + this.addNewLoraWidget(value); + const computed = this.computeSize(); + const tempHeight = (this as any)._tempHeight ?? 15; + this.size[1] = Math.max(tempHeight, computed[1]); + this.setDirtyCanvas(true, true); + } + } + // }, null, ["⚡️ Power Lora Chooser", ...loras]); + }, + null, + [...loras], + ); + }); + return true; + }, + ), + ); + } + + /** + * Hacks the `getSlotInPosition` call made from LiteGraph so we can show a custom context menu + * for widgets. + * + * Normally this method, called from LiteGraph's processContextMenu, will only get Inputs or + * Outputs. But that's not good enough because we we also want to provide a custom menu when + * clicking a widget for this node... so we are left to HACK once again! 
+ * + * To achieve this: + * - Here, in LiteGraph's processContextMenu it asks the clicked node to tell it which input or + * output the user clicked on in `getSlotInPosition` + * - We check, and if we didn't, then we see if we clicked a widget and, if so, pass back some + * data that looks like we clicked an output to fool LiteGraph like a silly child. + * - As LiteGraph continues in its `processContextMenu`, it will then immediately call + * the clicked node's `getSlotMenuOptions` when `getSlotInPosition` returns data. + * - So, just below, we can then give LiteGraph the ContextMenu options we have. + * + * The only issue is that LiteGraph also checks `input/output.type` to set the ContextMenu title, + * so we need to supply that property (and set it to what we want our title). Otherwise, this + * should be pretty clean. + */ + override getSlotInPosition(canvasX: number, canvasY: number): any { + const slot = super.getSlotInPosition(canvasX, canvasY); + // No slot, let's see if it's a widget. + if (!slot) { + let lastWidget = null; + for (const widget of this.widgets) { + // If last_y isn't set, something is wrong. Bail. + if (!widget.last_y) return; + if (canvasY > this.pos[1] + widget.last_y) { + lastWidget = widget; + continue; + } + break; + } + // Only care about lora widget clicks. + if (lastWidget?.name?.startsWith("lora_")) { + return { widget: lastWidget, output: { type: "LORA WIDGET" } }; + } + } + return slot; + } + + /** + * Working with the overridden `getSlotInPosition` above, this method checks if the passed in + * option is actually a widget from it and then hijacks the context menu all together. + */ + override getSlotMenuOptions(slot: any): ContextMenuItem[] | null { + // Oddly, LiteGraph doesn't call back into our node with a custom menu (even though it let's us + // define a custom menu to begin with... wtf?). So, we'll return null so the default is not + // triggered and then we'll just show one ourselves because.. yea. + if (slot?.widget?.name?.startsWith("lora_")) { + const widget = slot.widget as PowerLoraLoaderWidget; + const index = this.widgets.indexOf(widget); + const canMoveUp = !!this.widgets[index - 1]?.name?.startsWith("lora_"); + const canMoveDown = !!this.widgets[index + 1]?.name?.startsWith("lora_"); + const menuItems: ContextMenuItem[] = [ + { + content: `ℹ️ Show Info`, + callback: () => { + widget.showLoraInfoDialog(); + }, + }, + null, // Divider + { + content: `${widget.value.on ? "⚫" : "🟢"} Toggle ${widget.value.on ? "Off" : "On"}`, + callback: () => { + widget.value.on = !widget.value.on; + }, + }, + { + content: `⬆️ Move Up`, + disabled: !canMoveUp, + callback: () => { + moveArrayItem(this.widgets, widget, index - 1); + }, + }, + { + content: `⬇️ Move Down`, + disabled: !canMoveDown, + callback: () => { + moveArrayItem(this.widgets, widget, index + 1); + }, + }, + { + content: `🗑️ Remove`, + callback: () => { + removeArrayItem(this.widgets, widget); + }, + }, + ]; + + let canvas = app.canvas as LGraphCanvas; + new LiteGraph.ContextMenu( + menuItems, + { title: "LORA WIDGET", event: rgthree.lastAdjustedMouseEvent! }, + canvas.getCanvasWindow(), + ); + + return null; + } + return this.defaultGetSlotMenuOptions(slot); + } + + /** + * When `refreshComboInNode` is called from ComfyUI, then we'll kick off a fresh loras fetch. + */ + refreshComboInNode(defs: any) { + rgthreeApi.getLoras(true); + } + + /** + * Returns true if there are any Lora Widgets. Useful for widgets to ask as they render. 
+ */ + hasLoraWidgets() { + return !!this.widgets?.find((w) => w.name?.startsWith("lora_")); + } + + /** + * This will return true when all lora widgets are on, false when all are off, or null if it's + * mixed. + */ + allLorasState() { + let allOn = true; + let allOff = true; + for (const widget of this.widgets) { + if (widget.name?.startsWith("lora_")) { + const on = widget.value?.on; + allOn = allOn && on === true; + allOff = allOff && on === false; + if (!allOn && !allOff) { + return null; + } + } + } + return allOn && this.widgets?.length ? true : false; + } + + /** + * Toggles all the loras on or off. + */ + toggleAllLoras() { + const allOn = this.allLorasState(); + const toggledTo = !allOn ? true : false; + for (const widget of this.widgets) { + if (widget.name?.startsWith("lora_")) { + widget.value.on = toggledTo; + } + } + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, NODE_CLASS); + } + + static override onRegisteredForOverride(comfyClass: any, ctxClass: any) { + addConnectionLayoutSupport(NODE_CLASS, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + setTimeout(() => { + NODE_CLASS.category = comfyClass.category; + }); + } + + override getHelp() { + return ` +

+ The ${this.type!.replace("(rgthree)", "")} is a powerful node that condenses hundreds of pixels
+ of functionality into a single, dynamic node, letting you add loras, change strengths,
+ and quickly toggle them on/off, all without taking up half your screen.
+
+   • Add as many Loras as you would like by clicking the "+ Add Lora" button.
+     There's no real limit!
+   • Right-click on a Lora widget for special options to move the lora up or down
+     (purely presentational, no effect on the generated image), toggle it on/off,
+     or delete the row altogether.
+   • Properties. You can change the following properties by right-clicking on the node
+     and selecting "Properties" or "Properties Panel" from the menu:
+       • ${PROP_LABEL_SHOW_STRENGTHS} - Switch between showing a single, simple strength
+         (used for both model and clip) and a more advanced view where both model and clip
+         strengths can be set separately.
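As a sketch of the data model only (not part of the patch; the lora filename below is made up), each row stores roughly the following value, and the "Show Strengths" property controls whether the second strength is kept when the workflow is serialized:

// Shape of one lora row's widget value (mirrors PowerLoraLoaderWidgetValue defined later in this file).
type LoraRowValue = {
  on: boolean;                  // row toggle
  lora: string | null;          // selected lora, or null / "None"
  strength: number;             // model strength (also used for clip in "Single Strength" mode)
  strengthTwo?: number | null;  // clip strength; only serialized in "Separate Model & Clip" mode
};

// "Single Strength": one number drives both model and clip; strengthTwo is dropped on save.
const single: LoraRowValue = { on: true, lora: "example-lora.safetensors", strength: 0.8 };

// "Separate Model & Clip": model and clip strengths are stored independently.
const separate: LoraRowValue = {
  on: true,
  lora: "example-lora.safetensors",
  strength: 0.8,
  strengthTwo: 1.0,
};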
`; + } +} + +/** + * The PowerLoraLoaderHeaderWidget that renders a toggle all switch, as well as some title info + * (more necessary for the double model & clip strengths to label them). + */ +class PowerLoraLoaderHeaderWidget extends RgthreeBaseWidget<{ type: string }> { + private showModelAndClip: boolean | null = null; + + value = { type: "PowerLoraLoaderHeaderWidget" }; + + protected override hitAreas: RgthreeBaseHitAreas<"toggle"> = { + toggle: { bounds: [0, 0] as Vector2, onDown: this.onToggleDown }, + }; + + constructor(name: string = "PowerLoraLoaderHeaderWidget") { + super(name); + } + + draw( + ctx: CanvasRenderingContext2D, + node: RgthreePowerLoraLoader, + w: number, + posY: number, + height: number, + ) { + if (!node.hasLoraWidgets()) { + return; + } + // Since draw is the loop that runs, this is where we'll check the property state (rather than + // expect the node to tell us it's state etc). + this.showModelAndClip = + node.properties[PROP_LABEL_SHOW_STRENGTHS] === PROP_VALUE_SHOW_STRENGTHS_SEPARATE; + const margin = 10; + const innerMargin = margin * 0.33; + const lowQuality = isLowQuality(); + const allLoraState = node.allLorasState(); + + // Move slightly down. We don't have a border and this feels a bit nicer. + posY += 2; + const midY = posY + height * 0.5; + let posX = 10; + ctx.save(); + this.hitAreas.toggle.bounds = drawTogglePart(ctx, { posX, posY, height, value: allLoraState }); + + if (!lowQuality) { + posX += this.hitAreas.toggle.bounds[1] + innerMargin; + + ctx.globalAlpha = app.canvas.editor_alpha * 0.55; + ctx.fillStyle = LiteGraph.WIDGET_TEXT_COLOR; + ctx.textAlign = "left"; + ctx.textBaseline = "middle"; + ctx.fillText("Toggle All", posX, midY); + + let rposX = node.size[0] - margin - innerMargin - innerMargin; + ctx.textAlign = "center"; + ctx.fillText( + this.showModelAndClip ? "Clip" : "Strength", + rposX - drawNumberWidgetPart.WIDTH_TOTAL / 2, + midY, + ); + if (this.showModelAndClip) { + rposX = rposX - drawNumberWidgetPart.WIDTH_TOTAL - innerMargin * 2; + ctx.fillText("Model", rposX - drawNumberWidgetPart.WIDTH_TOTAL / 2, midY); + } + } + ctx.restore(); + } + + /** + * Handles a pointer down on the toggle's defined hit area. + */ + onToggleDown(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + (node as RgthreePowerLoraLoader).toggleAllLoras(); + this.cancelMouseDown(); + return true; + } +} + +const DEFAULT_LORA_WIDGET_DATA: PowerLoraLoaderWidgetValue = { + on: true, + lora: null as string | null, + strength: 1, + strengthTwo: null as number | null, +}; + +type PowerLoraLoaderWidgetValue = { + on: boolean; + lora: string | null; + strength: number; + strengthTwo: number | null; +}; + +/** + * The PowerLoaderWidget that combines several custom drawing and functionality in a single row. + */ +class PowerLoraLoaderWidget extends RgthreeBaseWidget { + /** Whether the strength has changed with mouse move (to cancel mouse up). 
*/ + private haveMouseMovedStrength = false; + private loraInfoPromise: Promise | null = null; + private loraInfo: RgthreeModelInfo | null = null; + + private showModelAndClip: boolean | null = null; + + protected override hitAreas: RgthreeBaseHitAreas< + | "toggle" + | "lora" + // | "info" + | "strengthDec" + | "strengthVal" + | "strengthInc" + | "strengthAny" + | "strengthTwoDec" + | "strengthTwoVal" + | "strengthTwoInc" + | "strengthTwoAny" + > = { + toggle: { bounds: [0, 0] as Vector2, onDown: this.onToggleDown }, + lora: { bounds: [0, 0] as Vector2, onDown: this.onLoraDown }, + // info: { bounds: [0, 0] as Vector2, onDown: this.onInfoDown }, + + strengthDec: { bounds: [0, 0] as Vector2, onDown: this.onStrengthDecDown }, + strengthVal: { bounds: [0, 0] as Vector2, onUp: this.onStrengthValUp }, + strengthInc: { bounds: [0, 0] as Vector2, onDown: this.onStrengthIncDown }, + strengthAny: { bounds: [0, 0] as Vector2, onMove: this.onStrengthAnyMove }, + + strengthTwoDec: { bounds: [0, 0] as Vector2, onDown: this.onStrengthTwoDecDown }, + strengthTwoVal: { bounds: [0, 0] as Vector2, onUp: this.onStrengthTwoValUp }, + strengthTwoInc: { bounds: [0, 0] as Vector2, onDown: this.onStrengthTwoIncDown }, + strengthTwoAny: { bounds: [0, 0] as Vector2, onMove: this.onStrengthTwoAnyMove }, + }; + + constructor(name: string) { + super(name); + } + + private _value = { + on: true, + lora: null as string | null, + strength: 1, + strengthTwo: null as number | null, + }; + + set value(v) { + this._value = v; + // In case widgets are messed up, we can correct course here. + if (typeof this._value !== "object") { + this._value = { ...DEFAULT_LORA_WIDGET_DATA }; + if (this.showModelAndClip) { + this._value.strengthTwo = this._value.strength; + } + } + this.getLoraInfo(); + } + + get value() { + return this._value; + } + + setLora(lora: string) { + this._value.lora = lora; + this.getLoraInfo(); + } + + /** Draws our widget with a toggle, lora selector, and number selector all in a single row. */ + draw(ctx: CanvasRenderingContext2D, node: TLGraphNode, w: number, posY: number, height: number) { + // Since draw is the loop that runs, this is where we'll check the property state (rather than + // expect the node to tell us it's state etc). + let currentShowModelAndClip = + node.properties[PROP_LABEL_SHOW_STRENGTHS] === PROP_VALUE_SHOW_STRENGTHS_SEPARATE; + if (this.showModelAndClip !== currentShowModelAndClip) { + let oldShowModelAndClip = this.showModelAndClip; + this.showModelAndClip = currentShowModelAndClip; + if (this.showModelAndClip) { + // If we're setting show both AND we're not null, then re-set to the current strength. + if (oldShowModelAndClip != null) { + this.value.strengthTwo = this.value.strength ?? 1; + } + } else { + this.value.strengthTwo = null; + this.hitAreas.strengthTwoDec.bounds = [0, -1]; + this.hitAreas.strengthTwoVal.bounds = [0, -1]; + this.hitAreas.strengthTwoInc.bounds = [0, -1]; + this.hitAreas.strengthTwoAny.bounds = [0, -1]; + } + } + + ctx.save(); + const margin = 10; + const innerMargin = margin * 0.33; + const lowQuality = isLowQuality(); + const midY = posY + height * 0.5; + + // We'll move posX along as we draw things. + let posX = margin; + + // Draw the background. + drawRoundedRectangle(ctx, { posX, posY, height, width: node.size[0] - margin * 2 }); + + // Draw the toggle + this.hitAreas.toggle.bounds = drawTogglePart(ctx, { posX, posY, height, value: this.value.on }); + posX += this.hitAreas.toggle.bounds[1] + innerMargin; + + // If low quality, then we're done rendering. 
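+ // ("Low quality" generally means the canvas is zoomed far enough out that fine widget detail would not be legible anyway.)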
+ if (lowQuality) { + ctx.restore(); + return; + } + + // If we're not toggled on, then make everything after faded. + if (!this.value.on) { + ctx.globalAlpha = app.canvas.editor_alpha * 0.4; + } + + ctx.fillStyle = LiteGraph.WIDGET_TEXT_COLOR; + + // Now, we draw the strength number part on the right, so we know the width of it to draw the + // lora label as flexible. + let rposX = node.size[0] - margin - innerMargin - innerMargin; + + const strengthValue = this.showModelAndClip + ? this.value.strengthTwo ?? 1 + : this.value.strength ?? 1; + + let textColor: string | undefined = undefined; + if (this.loraInfo?.strengthMax != null && strengthValue > this.loraInfo?.strengthMax) { + textColor = "#c66"; + } else if (this.loraInfo?.strengthMin != null && strengthValue < this.loraInfo?.strengthMin) { + textColor = "#c66"; + } + + const [leftArrow, text, rightArrow] = drawNumberWidgetPart(ctx, { + posX: node.size[0] - margin - innerMargin - innerMargin, + posY, + height, + value: strengthValue, + direction: -1, + textColor, + }); + + this.hitAreas.strengthDec.bounds = leftArrow; + this.hitAreas.strengthVal.bounds = text; + this.hitAreas.strengthInc.bounds = rightArrow; + this.hitAreas.strengthAny.bounds = [leftArrow[0], rightArrow[0] + rightArrow[1] - leftArrow[0]]; + + rposX = leftArrow[0] - innerMargin; + + if (this.showModelAndClip) { + rposX -= innerMargin; + // If we're showing both, then the rightmost we just drew is our "strengthTwo", so reset and + // then draw our model ("strength" one) to the left. + this.hitAreas.strengthTwoDec.bounds = this.hitAreas.strengthDec.bounds; + this.hitAreas.strengthTwoVal.bounds = this.hitAreas.strengthVal.bounds; + this.hitAreas.strengthTwoInc.bounds = this.hitAreas.strengthInc.bounds; + this.hitAreas.strengthTwoAny.bounds = this.hitAreas.strengthAny.bounds; + + let textColor: string | undefined = undefined; + if (this.loraInfo?.strengthMax != null && this.value.strength > this.loraInfo?.strengthMax) { + textColor = "#c66"; + } else if ( + this.loraInfo?.strengthMin != null && + this.value.strength < this.loraInfo?.strengthMin + ) { + textColor = "#c66"; + } + const [leftArrow, text, rightArrow] = drawNumberWidgetPart(ctx, { + posX: rposX, + posY, + height, + value: this.value.strength ?? 
1, + direction: -1, + textColor, + }); + this.hitAreas.strengthDec.bounds = leftArrow; + this.hitAreas.strengthVal.bounds = text; + this.hitAreas.strengthInc.bounds = rightArrow; + this.hitAreas.strengthAny.bounds = [ + leftArrow[0], + rightArrow[0] + rightArrow[1] - leftArrow[0], + ]; + rposX = leftArrow[0] - innerMargin; + } + + const infoIconSize = height * 0.66; + const infoWidth = infoIconSize + innerMargin + innerMargin; + // Draw an info emoji; if checks if it's enabled (to quickly turn it on or off) + if ((this.hitAreas as any)["info"]) { + rposX -= innerMargin; + drawInfoIcon(ctx, rposX - infoIconSize, posY + (height - infoIconSize) / 2, infoIconSize); + // ctx.fillText('ℹ', posX, midY); + (this.hitAreas as any).info.bounds = [rposX - infoIconSize, infoWidth]; + rposX = rposX - infoIconSize - innerMargin; + } + + // Draw lora label + const loraWidth = rposX - posX; + ctx.textAlign = "left"; + ctx.textBaseline = "middle"; + const loraLabel = String(this.value?.lora || "None"); + ctx.fillText(fitString(ctx, loraLabel, loraWidth), posX, midY); + + this.hitAreas.lora.bounds = [posX, loraWidth]; + posX += loraWidth + innerMargin; + + ctx.globalAlpha = app.canvas.editor_alpha; + ctx.restore(); + } + + serializeValue(serializedNode: SerializedLGraphNode, widgetIndex: number) { + const v = { ...this.value }; + // Never send the second value to the backend if we're not showing it, otherwise, let's just + // make sure it's not null. + if (!this.showModelAndClip) { + delete (v as any).strengthTwo; + } else { + this.value.strengthTwo = this.value.strengthTwo ?? 1; + v.strengthTwo = this.value.strengthTwo; + } + return v; + } + + onToggleDown(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + this.value.on = !this.value.on; + this.cancelMouseDown(); // Clear the down since we handle it. + return true; + } + + onInfoDown(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + this.showLoraInfoDialog(); + } + + onLoraDown(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + showLoraChooser(event, (value: ContextMenuItem) => { + if (typeof value === "string") { + this.value.lora = value; + this.loraInfo = null; + this.getLoraInfo(); + } + node.setDirtyCanvas(true, true); + }); + this.cancelMouseDown(); + } + + onStrengthDecDown(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + this.stepStrength(-1, false); + } + onStrengthIncDown(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + this.stepStrength(1, false); + } + onStrengthTwoDecDown(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + this.stepStrength(-1, true); + } + onStrengthTwoIncDown(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + this.stepStrength(1, true); + } + + onStrengthAnyMove(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + this.doOnStrengthAnyMove(event, false); + } + + onStrengthTwoAnyMove(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + this.doOnStrengthAnyMove(event, true); + } + + private doOnStrengthAnyMove(event: AdjustedMouseEvent, isTwo = false) { + if (event.deltaX) { + let prop: "strengthTwo" | "strength" = isTwo ? "strengthTwo" : "strength"; + this.haveMouseMovedStrength = true; + this.value[prop] = (this.value[prop] ?? 
1) + event.deltaX * 0.05; + } + } + + onStrengthValUp(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + this.doOnStrengthValUp(event, false); + } + + onStrengthTwoValUp(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode) { + this.doOnStrengthValUp(event, true); + } + + private doOnStrengthValUp(event: AdjustedMouseEvent, isTwo = false) { + if (this.haveMouseMovedStrength) return; + let prop: "strengthTwo" | "strength" = isTwo ? "strengthTwo" : "strength"; + const canvas = app.canvas as LGraphCanvas; + canvas.prompt("Value", this.value[prop], (v: string) => (this.value[prop] = Number(v)), event); + } + + override onMouseUp(event: AdjustedMouseEvent, pos: Vector2, node: TLGraphNode): boolean | void { + super.onMouseUp(event, pos, node); + this.haveMouseMovedStrength = false; + } + + showLoraInfoDialog() { + if (!this.value.lora || this.value.lora === "None") { + return; + } + const infoDialog = new RgthreeInfoDialog(this.value.lora).show(); + infoDialog.addEventListener("close", ((e: CustomEvent<{ dirty: boolean }>) => { + if (e.detail.dirty) { + this.getLoraInfo(true); + } + }) as EventListener); + } + + private stepStrength(direction: -1 | 1, isTwo = false) { + let step = 0.05; + let prop: "strengthTwo" | "strength" = isTwo ? "strengthTwo" : "strength"; + let strength = (this.value[prop] ?? 1) + step * direction; + this.value[prop] = Math.round(strength * 100) / 100; + } + + private getLoraInfo(force = false) { + if (!this.loraInfoPromise || force == true) { + let promise; + if (this.value.lora && this.value.lora != "None") { + promise = MODEL_INFO_SERVICE.getLora(this.value.lora, force, true); + } else { + promise = Promise.resolve(null); + } + this.loraInfoPromise = promise.then((v) => (this.loraInfo = v)); + } + return this.loraInfoPromise; + } +} + +/** An uniformed name reference to the node class. */ +const NODE_CLASS = RgthreePowerLoraLoader; + +/** Register the node. */ +app.registerExtension({ + name: "rgthree.PowerLoraLoader", + async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + if (nodeData.name === NODE_CLASS.type) { + NODE_CLASS.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/power_prompt.ts b/rgthree-comfy/src_web/comfyui/power_prompt.ts new file mode 100644 index 0000000000000000000000000000000000000000..930261136f76d2827cafa3432c5db564118e842a --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/power_prompt.ts @@ -0,0 +1,56 @@ +import { app } from "scripts/app.js"; +import type { LGraphNode as TLGraphNode } from "typings/litegraph.js"; +import type { + ComfyApp, + ComfyObjectInfo, + ComfyGraphNode, + ComfyNodeConstructor, +} from "typings/comfy.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { PowerPrompt } from "./base_power_prompt.js"; +import { NodeTypesString } from "./constants.js"; + +let nodeData: ComfyObjectInfo | null = null; +app.registerExtension({ + name: "rgthree.PowerPrompt", + async beforeRegisterNodeDef( + nodeType: ComfyNodeConstructor, + passedNodeData: ComfyObjectInfo, + _app: ComfyApp, + ) { + if (passedNodeData.name.includes("Power Prompt") && passedNodeData.name.includes("rgthree")) { + nodeData = passedNodeData; + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + onNodeCreated ? 
onNodeCreated.apply(this, []) : undefined; + (this as any).powerPrompt = new PowerPrompt(this as ComfyGraphNode, passedNodeData); + }; + addConnectionLayoutSupport(nodeType, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + } + }, + async loadedGraphNode(node: TLGraphNode) { + if (node.type === NodeTypesString.POWER_PROMPT) { + setTimeout(() => { + // If the first output is STRING, then it's the text output from the initial launch. + // Let's port it to the new + if (node.outputs[0]!.type === "STRING") { + if (node.outputs[0]!.links) { + node.outputs[3]!.links = node.outputs[3]!.links || []; + for (const link of node.outputs[0]!.links) { + node.outputs[3]!.links.push(link); + app.graph.links[link]!.origin_slot = 3; + } + node.outputs[0]!.links = null; + } + node.outputs[0]!.type = nodeData!.output![0] as string; + node.outputs[0]!.name = nodeData!.output_name![0] || (node.outputs[0]!.type as string); + node.outputs[0]!.color_on = undefined; + node.outputs[0]!.color_off = undefined; + } + }, 50); + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/random_unmuter.ts b/rgthree-comfy/src_web/comfyui/random_unmuter.ts new file mode 100644 index 0000000000000000000000000000000000000000..7f45c79f4693293b2e9803354d60322f30e8a844 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/random_unmuter.ts @@ -0,0 +1,118 @@ +import type { LGraphNode } from "typings/litegraph.js"; +import type { RgthreeBaseVirtualNodeConstructor } from "typings/rgthree.js"; + +import { app } from "scripts/app.js"; +import { BaseAnyInputConnectedNode } from "./base_any_input_connected_node.js"; +import { NodeTypesString } from "./constants.js"; +import { rgthree } from "./rgthree.js"; +import { getConnectedInputNodesAndFilterPassThroughs } from "./utils.js"; + +const MODE_MUTE = 2; +const MODE_ALWAYS = 0; + +class RandomUnmuterNode extends BaseAnyInputConnectedNode { + static override exposedActions = ["Mute all", "Enable all"]; + + static override type = NodeTypesString.RANDOM_UNMUTER; + override comfyClass = NodeTypesString.RANDOM_UNMUTER; + static override title = RandomUnmuterNode.type; + readonly modeOn = MODE_ALWAYS; + readonly modeOff = MODE_MUTE; + + tempEnabledNode: LGraphNode | null = null; + processingQueue: boolean = false; + + onQueueBound = this.onQueue.bind(this); + onQueueEndBound = this.onQueueEnd.bind(this); + onGraphtoPromptBound = this.onGraphtoPrompt.bind(this); + onGraphtoPromptEndBound = this.onGraphtoPromptEnd.bind(this); + + constructor(title = RandomUnmuterNode.title) { + super(title); + + rgthree.addEventListener("queue", this.onQueueBound); + rgthree.addEventListener("queue-end", this.onQueueEndBound); + rgthree.addEventListener("graph-to-prompt", this.onGraphtoPromptBound); + rgthree.addEventListener("graph-to-prompt-end", this.onGraphtoPromptEndBound); + this.onConstructed(); + } + + override onRemoved() { + rgthree.removeEventListener("queue", this.onQueueBound); + rgthree.removeEventListener("queue-end", this.onQueueEndBound); + rgthree.removeEventListener("graph-to-prompt", this.onGraphtoPromptBound); + rgthree.removeEventListener("graph-to-prompt-end", this.onGraphtoPromptEndBound); + } + + onQueue(event: Event) { + this.processingQueue = true; + } + onQueueEnd(event: Event) { + this.processingQueue = false; + } + onGraphtoPrompt(event: Event) { + if (!this.processingQueue) { + return; + } + this.tempEnabledNode = null; + // Check that all are muted and, if so, choose one to unmute. 
+ const linkedNodes = getConnectedInputNodesAndFilterPassThroughs(this); + let allMuted = true; + if (linkedNodes.length) { + for (const node of linkedNodes) { + if (node.mode !== this.modeOff) { + allMuted = false; + break; + } + } + if (allMuted) { + this.tempEnabledNode = linkedNodes[Math.floor(Math.random() * linkedNodes.length)] || null; + if (this.tempEnabledNode) { + this.tempEnabledNode.mode = this.modeOn; + } + } + } + } + onGraphtoPromptEnd(event: Event) { + if (this.tempEnabledNode) { + this.tempEnabledNode.mode = this.modeOff; + this.tempEnabledNode = null; + } + } + + override handleLinkedNodesStabilization(linkedNodes: LGraphNode[]): void { + // No-op, no widgets. + } + + override getHelp(): string { + return ` +

+ Use this node to unmute one of its inputs at random when the graph is queued (and then
+ immediately mute it back).
+
+   • NOTE: All input nodes MUST be muted to start; otherwise this node will not randomly
+     unmute one. (This is powerful, as the generated image can be dragged back in and the
+     chosen input will already be unmuted and work without any further action.)
+   • TIP: Connect a Repeater's output to this node's input and place that Repeater on a
+     group without any other inputs, and it will mute/unmute the entire group.
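A minimal sketch of the behavior described above (not part of the patch; node shapes simplified): around each queued prompt the node unmutes one input, then restores it.

const MODE_ALWAYS = 0, MODE_MUTE = 2;
interface SketchNode { mode: number; }

let tempEnabled: SketchNode | null = null;

// Called while the prompt is being built from the graph.
function onGraphToPrompt(linkedInputs: SketchNode[]) {
  // Only acts when every connected input is muted; otherwise nothing changes.
  if (!linkedInputs.length || linkedInputs.some((n) => n.mode !== MODE_MUTE)) return;
  tempEnabled = linkedInputs[Math.floor(Math.random() * linkedInputs.length)]!;
  tempEnabled.mode = MODE_ALWAYS; // Unmute exactly one input for this prompt.
}

// Called once the prompt has been captured.
function onGraphToPromptEnd() {
  if (tempEnabled) {
    tempEnabled.mode = MODE_MUTE; // Mute it right back.
    tempEnabled = null;
  }
}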
+ `; + } +} + +app.registerExtension({ + name: "rgthree.RandomUnmuter", + registerCustomNodes() { + RandomUnmuterNode.setUp(); + }, + loadedGraphNode(node: LGraphNode) { + if (node.type == RandomUnmuterNode.title) { + (node as any)._tempWidth = node.size[0]; + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/reroute.ts b/rgthree-comfy/src_web/comfyui/reroute.ts new file mode 100644 index 0000000000000000000000000000000000000000..d656c04ed43f35c40a1baab43a622905317f5489 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/reroute.ts @@ -0,0 +1,1258 @@ +import { app } from "scripts/app.js"; +import { + getWidgetConfig, + mergeIfValid, + setWidgetConfig, + // @ts-ignore +} from "../../extensions/core/widgetInputs.js"; +// @ts-ignore +import { rgthreeConfig } from "rgthree/config.js"; +import { rgthree } from "./rgthree.js"; +import type { + Vector2, + LLink, + LGraphCanvas as TLGraphCanvas, + LGraph as TLGraph, + SerializedLGraphNode, + INodeInputSlot, + INodeOutputSlot, + LGraphNode as TLGraphNode, +} from "typings/litegraph.js"; +import { + IoDirection, + LAYOUT_CLOCKWISE, + LAYOUT_LABEL_OPPOSITES, + LAYOUT_LABEL_TO_DATA, + addConnectionLayoutSupport, + addMenuItem, + getSlotLinks, + isValidConnection, + setConnectionsLayout, + waitForCanvas, +} from "./utils.js"; +import { SERVICE as KEY_EVENT_SERVICE } from "./services/key_events_services.js"; +import { wait } from "rgthree/common/shared_utils.js"; +import { RgthreeBaseVirtualNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; + +const CONFIG_REROUTE = rgthreeConfig?.["nodes"]?.["reroute"] || {}; + +const CONFIG_FAST_REROUTE = CONFIG_REROUTE["fast_reroute"]; +const CONFIG_FAST_REROUTE_ENABLED = CONFIG_FAST_REROUTE["enabled"] ?? false; +const CONFIG_KEY_CREATE_WHILE_LINKING = CONFIG_FAST_REROUTE["key_create_while_dragging_link"]; +const CONFIG_KEY_ROTATE = CONFIG_FAST_REROUTE["key_rotate"]; +const CONFIG_KEY_RESIZE = CONFIG_FAST_REROUTE["key_resize"]; +const CONFIG_KEY_MOVE = CONFIG_FAST_REROUTE["key_move"]; +const CONFIG_KEY_CXN_INPUT = CONFIG_FAST_REROUTE["key_connections_input"]; +const CONFIG_KEY_CXN_OUTPUT = CONFIG_FAST_REROUTE["key_connections_output"]; + +let configWidth = Math.max( + Math.round((Number(CONFIG_REROUTE["default_width"]) || 40) / 10) * 10, + 10, +); +let configHeight = Math.max( + Math.round((Number(CONFIG_REROUTE["default_height"]) || 30) / 10) * 10, + 10, +); +// Don't allow too small sizes. Granted, 400 is too small, but at least you can right click and +// resize... 10x10 you cannot. +while (configWidth * configHeight < 400) { + configWidth += 10; + configHeight += 10; +} +const configDefaultSize = [configWidth, configHeight] as Vector2; +const configResizable = !!CONFIG_REROUTE["default_resizable"]; +let configLayout: [string, string] = CONFIG_REROUTE["default_layout"]; +if (!Array.isArray(configLayout)) { + configLayout = ["Left", "Right"]; +} +if (!LAYOUT_LABEL_TO_DATA[configLayout[0]]) { + configLayout[0] = "Left"; +} +if (!LAYOUT_LABEL_TO_DATA[configLayout[1]] || configLayout[0] == configLayout[1]) { + configLayout[1] = LAYOUT_LABEL_OPPOSITES[configLayout[0]]!; +} + +type FastRerouteEntryCtx = { + node: TLGraphNode; + input?: INodeInputSlot; + output?: INodeOutputSlot; + slot: number; + pos: Vector2; +}; + +type FastRerouteEntry = { + node: RerouteNode; + previous: FastRerouteEntryCtx; + current?: FastRerouteEntryCtx; +}; + +/** + * RerouteService handles any coordination between reroute nodes and the system. 
Mostly, it's for + * fast-rerouting that can create a new reroute nodes while dragging a link. + */ +class RerouteService { + private isFastLinking = false; + private handledNewRerouteKeypress = false; + private connectingData: FastRerouteEntryCtx|null = null; + private fastReroutesHistory: FastRerouteEntry[] = []; + + private handleLinkingKeydownBound = this.handleLinkingKeydown.bind(this); + private handleLinkingKeyupBound = this.handleLinkingKeyup.bind(this); + + constructor() { + if (CONFIG_FAST_REROUTE_ENABLED && CONFIG_KEY_CREATE_WHILE_LINKING?.trim()) { + this.onCanvasSetUpListenerForLinking(); + } + } + + /** + * Waits for canvas to be available, then sets up a property accessor for `connecting_node` so + * we can start/stop monitoring for shortcut keys. + */ + async onCanvasSetUpListenerForLinking() { + const canvas = await waitForCanvas(); + + // With the new UI released in August 2024, ComfyUI changed LiteGraph's code, removing + // connecting_node, connecting_node, connecting_node, and connecting_node properties and instead + // using an array of connecting_links. We can try to accomodate both for a while. + const canvasProperty = true ? 'connecting_links' : 'connecting_node'; + (canvas as any)[`_${canvasProperty}`]; + const thisService = this; + Object.defineProperty(canvas, canvasProperty, { + get: function () { + return this[`_${canvasProperty}`]; + }, + set: function (value) { + const isValNull = !value || !value?.length; + const isPropNull = !this[`_${canvasProperty}`] || !this[`_${canvasProperty}`]?.length; + const isStartingLinking = !isValNull && isPropNull; + const isStoppingLinking = !isPropNull && isValNull; + this[`_${canvasProperty}`] = value; + if (isStartingLinking) { + thisService.startingLinking(); + } + if (isStoppingLinking) { + thisService.stoppingLinking(); + thisService.connectingData = null; + } + }, + }); + } + + /** + * When the user is actively dragging a link, listens for keydown events so we can enable + * shortcuts. + * + * Is only accessible if both `CONFIG_FAST_REROUTE_ENABLED` is true, and + * CONFIG_KEY_CREATE_WHILE_LINKING is not falsy/empty. + */ + private startingLinking() { + this.isFastLinking = true; + KEY_EVENT_SERVICE.addEventListener("keydown", this.handleLinkingKeydownBound as EventListener); + KEY_EVENT_SERVICE.addEventListener("keyup", this.handleLinkingKeyupBound as EventListener); + } + + /** + * When the user stops actively dragging a link, cleans up motnioring data and events. + * + * Is only accessible if both `CONFIG_FAST_REROUTE_ENABLED` is true, and + * CONFIG_KEY_CREATE_WHILE_LINKING is not falsy/empty. + */ + private stoppingLinking() { + this.isFastLinking = false; + this.fastReroutesHistory = []; + KEY_EVENT_SERVICE.removeEventListener("keydown", this.handleLinkingKeydownBound as EventListener); + KEY_EVENT_SERVICE.removeEventListener("keyup", this.handleLinkingKeyupBound as EventListener); + } + + /** + * Handles the keydown event. + * + * Is only accessible if both `CONFIG_FAST_REROUTE_ENABLED` is true, and + * CONFIG_KEY_CREATE_WHILE_LINKING is not falsy/empty. + */ + private handleLinkingKeydown(event: KeyboardEvent) { + if ( + !this.handledNewRerouteKeypress && + KEY_EVENT_SERVICE.areOnlyKeysDown(CONFIG_KEY_CREATE_WHILE_LINKING) + ) { + this.handledNewRerouteKeypress = true; + this.insertNewRerouteWhileLinking(); + } + } + + /** + * Handles the keyup event. + * + * Is only accessible if both `CONFIG_FAST_REROUTE_ENABLED` is true, and + * CONFIG_KEY_CREATE_WHILE_LINKING is not falsy/empty. 
+ */ + private handleLinkingKeyup(event: KeyboardEvent) { + if ( + this.handledNewRerouteKeypress && + !KEY_EVENT_SERVICE.areOnlyKeysDown(CONFIG_KEY_CREATE_WHILE_LINKING) + ) { + this.handledNewRerouteKeypress = false; + } + } + + private getConnectingData() : FastRerouteEntryCtx { + const oldCanvas = app.canvas as any; + if (oldCanvas.connecting_node && oldCanvas.connecting_slot != null && oldCanvas.connecting_pos?.length) { + return { + node: oldCanvas.connecting_node, + input: oldCanvas.connecting_input, + output: oldCanvas.connecting_output, + slot: oldCanvas.connecting_slot, + pos: [...oldCanvas.connecting_pos] as Vector2, + }; + } + const canvas = app.canvas; + if (canvas.connecting_links?.length) { + // Assume just the first. + const link = canvas.connecting_links[0]!; + return { + node: link.node, + input: link.input, + output: link.output, + slot: link.slot, + pos: [...link.pos], + }; + } + throw new Error("Error, handling linking keydown, but there's no link."); + } + + private setCanvasConnectingData(ctx: FastRerouteEntryCtx) { + const oldCanvas = app.canvas as any; + if (oldCanvas.connecting_node && oldCanvas.connecting_slot != null && oldCanvas.connecting_pos?.length) { + oldCanvas.connecting_node = ctx.node; + oldCanvas.connecting_input = ctx.input; + oldCanvas.connecting_output = ctx.output; + oldCanvas.connecting_slot = ctx.slot; + oldCanvas.connecting_pos = ctx.pos; + } + const canvas = app.canvas; + if (canvas.connecting_links?.length) { + // Assume just the first. + const link = canvas.connecting_links[0]!; + link.node = ctx.node; + link.input = ctx.input; + link.output = ctx.output; + link.slot = ctx.slot; + link.pos = ctx.pos; + } + } + + /** + * Inserts a new reroute (while linking) as called from key down handler. + * + * Is only accessible if both `CONFIG_FAST_REROUTE_ENABLED` is true, and + * CONFIG_KEY_CREATE_WHILE_LINKING is not falsy/empty. + */ + private insertNewRerouteWhileLinking() { + const canvas = app.canvas; + this.connectingData = this.getConnectingData(); + if (!this.connectingData) { + throw new Error("Error, handling linking keydown, but there's no link."); + } + + const data = this.connectingData; + const node = LiteGraph.createNode("Reroute (rgthree)") as RerouteNode; + const entry: FastRerouteEntry = { + node, + previous: {...this.connectingData}, + current: undefined, + }; + this.fastReroutesHistory.push(entry); + + let connectingDir = (data.input || data.output)?.dir; + if (!connectingDir) { + connectingDir = data.input ? LiteGraph.LEFT : LiteGraph.RIGHT; + } + + let newPos = canvas.convertEventToCanvasOffset({ + clientX: Math.round(canvas.last_mouse_position[0] / 10) * 10, + clientY: Math.round(canvas.last_mouse_position[1] / 10) * 10, + }); + entry.node.pos = newPos; + canvas.graph.add(entry.node); + canvas.selectNode(entry.node); + + // Find out which direction we're generally moving. + const distX = entry.node.pos[0] - data.pos[0]; + const distY = entry.node.pos[1] - data.pos[1]; + + const layout: [string, string] = ["Left", "Right"]; + if (distX > 0 && Math.abs(distX) > Math.abs(distY)) { + // To the right, and further right than up or down. + layout[0] = data.output ? "Left" : "Right"; + layout[1] = LAYOUT_LABEL_OPPOSITES[layout[0]]!; + node.pos[0] -= node.size[0] + 10; + node.pos[1] -= Math.round(node.size[1] / 2 / 10) * 10; + } else if (distX < 0 && Math.abs(distX) > Math.abs(distY)) { + // To the left, and further right than up or down. + layout[0] = data.output ? 
"Right" : "Left"; + layout[1] = LAYOUT_LABEL_OPPOSITES[layout[0]]!; + node.pos[1] -= Math.round(node.size[1] / 2 / 10) * 10; + } else if (distY < 0 && Math.abs(distY) > Math.abs(distX)) { + // Above and further above than left or right. + layout[0] = data.output ? "Bottom" : "Top"; + layout[1] = LAYOUT_LABEL_OPPOSITES[layout[0]]!; + node.pos[0] -= Math.round(node.size[0] / 2 / 10) * 10; + } else if (distY > 0 && Math.abs(distY) > Math.abs(distX)) { + // Below and further below than left or right. + layout[0] = data.output ? "Top" : "Bottom"; + layout[1] = LAYOUT_LABEL_OPPOSITES[layout[0]]!; + node.pos[0] -= Math.round(node.size[0] / 2 / 10) * 10; + node.pos[1] -= node.size[1] + 10; + } + setConnectionsLayout(entry.node, layout); + + if (data.output) { + data.node.connect(data.slot, entry.node, 0); + data.node = entry.node; + data.output = entry.node.outputs[0]!; + data.slot = 0; + data.pos = entry.node.getConnectionPos(false, 0); + } else { + entry.node.connect(0, data.node, data.slot); + data.node = entry.node; + data.input = entry.node.inputs[0]!; + data.slot = 0; + data.pos = entry.node.getConnectionPos(true, 0); + } + this.setCanvasConnectingData(data); + entry.current = {...this.connectingData}; + + app.graph.setDirtyCanvas(true, true); + } + + /** + * Is called from a reroute node when it is resized or moved so the service can check if we're + * actively linking to it, and it can update the linking data so the connection moves too, by + * updating `connecting_pos`. + */ + handleMoveOrResizeNodeMaybeWhileDragging(node: RerouteNode) { + const data = this.connectingData!; + if (this.isFastLinking && node === data?.node) { + const entry = this.fastReroutesHistory[this.fastReroutesHistory.length - 1]; + if (entry) { + data.pos = entry.node.getConnectionPos(!!data.input, 0); + this.setCanvasConnectingData(data); + } + } + } + + /** + * Is called from a reroute node when it is deleted so the service can check if we're actively + * linking to it and go "back" in history to the previous node. + */ + handleRemovedNodeMaybeWhileDragging(node: RerouteNode) { + const currentEntry = this.fastReroutesHistory[this.fastReroutesHistory.length - 1]; + if (currentEntry?.node === node) { + this.setCanvasConnectingData(currentEntry.previous); + this.fastReroutesHistory.splice(this.fastReroutesHistory.length - 1, 1); + if (currentEntry.previous.node) { + app.canvas.selectNode(currentEntry.previous.node); + } + } + } +} + +const SERVICE = new RerouteService(); + +/** + * The famous ReroutNode, that has true multidirectional, expansive sizes, etc. + */ +class RerouteNode extends RgthreeBaseVirtualNode { + static override title = NodeTypesString.REROUTE; + static override type = NodeTypesString.REROUTE; + override comfyClass = NodeTypesString.REROUTE; + + static readonly title_mode = LiteGraph.NO_TITLE; + + static collapsable = false; + static layout_slot_offset = 5; + static size: Vector2 = configDefaultSize; // Starting size, read from within litegraph.core + + override isVirtualNode = true; + readonly hideSlotLabels = true; + + private schedulePromise: Promise | null = null; + + defaultConnectionsLayout = Array.from(configLayout); + + /** Shortcuts defined in the config. 
*/ + private shortcuts = { + rotate: { keys: CONFIG_KEY_ROTATE, state: false }, + connection_input: { keys: CONFIG_KEY_CXN_INPUT, state: false }, + connection_output: { keys: CONFIG_KEY_CXN_OUTPUT, state: false }, + resize: { + keys: CONFIG_KEY_RESIZE, + state: false, + initialMousePos: [-1, -1] as Vector2, + initialNodeSize: [-1, -1] as Vector2, + initialNodePos: [-1, -1] as Vector2, + resizeOnSide: [-1, -1] as Vector2, + }, + move: { + keys: CONFIG_KEY_MOVE, + state: false, + initialMousePos: [-1, -1] as Vector2, + initialNodePos: [-1, -1] as Vector2, + }, + }; + + constructor(title = RerouteNode.title) { + super(title); + this.onConstructed(); + } + + override onConstructed(): boolean { + this.setResizable(this.properties["resizable"] ?? configResizable); + this.size = RerouteNode.size; // Starting size. + this.addInput("", "*"); + this.addOutput("", "*"); + setTimeout(() => this.applyNodeSize(), 20); + return super.onConstructed(); + } + + override configure(info: SerializedLGraphNode) { + // Patch a small issue (~14h) where multiple OPT_CONNECTIONS may have been created. + // https://github.com/rgthree/rgthree-comfy/issues/206 + // TODO: This can probably be removed within a few weeks. + if (info.outputs?.length) { + info.outputs.length = 1; + } + if (info.inputs?.length) { + info.inputs.length = 1; + } + super.configure(info); + this.configuring = true; + this.setResizable(this.properties["resizable"] ?? configResizable); + this.applyNodeSize(); + this.configuring = false; + } + + setResizable(resizable: boolean) { + this.properties["resizable"] = !!resizable; + this.resizable = this.properties["resizable"]; + } + + override clone() { + const cloned = super.clone(); + cloned.inputs[0]!.type = "*"; + cloned.outputs[0]!.type = "*"; + return cloned; + } + + /** + * Copied a good bunch of this from the original reroute included with comfy. + */ + override onConnectionsChange( + type: number, + _slotIndex: number, + connected: boolean, + _link_info: LLink, + _ioSlot: INodeOutputSlot | INodeInputSlot, + ) { + // Prevent multiple connections to different types when we have no input + if (connected && type === LiteGraph.OUTPUT) { + // Ignore wildcard nodes as these will be updated to real types + const types = new Set( + this.outputs[0]!.links!.map((l) => app.graph.links[l]!.type).filter((t) => t !== "*"), + ); + if (types.size > 1) { + const linksToDisconnect = []; + for (let i = 0; i < this.outputs[0]!.links!.length - 1; i++) { + const linkId = this.outputs[0]!.links![i]!; + const link = app.graph.links[linkId]; + linksToDisconnect.push(link); + } + for (const link of linksToDisconnect) { + const node = app.graph.getNodeById(link!.target_id)!; + node.disconnectInput(link!.target_slot); + } + } + } + this.scheduleStabilize(); + } + + override onDrawForeground(ctx: CanvasRenderingContext2D, canvas: TLGraphCanvas): void { + if (this.properties?.["showLabel"]) { + // ComfyUI seemed to break us again, but couldn't repro. No reason to not check, I guess. + // https://github.com/rgthree/rgthree-comfy/issues/71 + const low_quality = canvas?.ds?.scale && canvas.ds.scale < 0.6; + if (low_quality || this.size[0] <= 10) { + return; + } + const fontSize = Math.min(14, (this.size[1] * 0.65) | 0); + ctx.save(); + ctx.fillStyle = "#888"; + ctx.font = `${fontSize}px Arial`; + ctx.textAlign = "center"; + ctx.textBaseline = "middle"; + ctx.fillText( + String( + this.title && this.title !== RerouteNode.title + ? 
this.title + : this.outputs?.[0]?.type || "", + ), + this.size[0] / 2, + this.size[1] / 2, + this.size[0] - 30, + ); + ctx.restore(); + } + } + + /** Finds the input slot; since we only ever have one, this is always 0. */ + override findInputSlot(name: string): number { + return 0; + } + + /** Finds the output slot; since we only ever have one, this is always 0. */ + override findOutputSlot(name: string): number { + return 0; + } + + override disconnectOutput(slot: string | number, targetNode?: TLGraphNode | undefined): boolean { + return super.disconnectOutput(slot, targetNode); + } + + override disconnectInput(slot: string | number): boolean { + // [🤮] ComfyUI's reroute nodes will disconnect our input if it doesn't yet match (ours being + // "*" and it's being a type. This mostly happens if we're converting reroutes to rgthree + // reroutes, the old reroute does a check and calls disconnectInput. Luckily, we can be smarter + // and check if we're being asked to disconnect from an old reroute while we're replacing + // reroute nodes (via rgthree.replacingReroute state). + if (rgthree.replacingReroute != null && this.inputs[0]?.link) { + const graph = app.graph as TLGraph; + const link = graph.links[this.inputs[0].link]; + const node = graph.getNodeById(link?.origin_id); + // We'll also be asked to disconnect when the old one is removed, so we only want to stop a + // disconnect when the connected node is NOT the one being removed/replaced. + if (rgthree.replacingReroute !== node?.id) { + return false; + } + } + return super.disconnectInput(slot); + } + + scheduleStabilize(ms = 64) { + if (!this.schedulePromise) { + this.schedulePromise = new Promise((resolve) => { + setTimeout(() => { + this.schedulePromise = null; + this.stabilize(); + resolve(); + }, ms); + }); + } + return this.schedulePromise; + } + + stabilize() { + // If we are currently "configuring" then skip this stabilization. The connected nodes may + // not yet be configured. + if (this.configuring) { + return; + } + // Find root input + let currentNode: TLGraphNode | null = this; + let updateNodes = []; + let input = null; + let inputType = null; + let inputNode = null; + let inputNodeOutputSlot = null; + while (currentNode) { + updateNodes.unshift(currentNode); + const linkId: number | null = currentNode.inputs[0]!.link; + if (linkId !== null) { + const link: LLink = (app.graph as TLGraph).links[linkId]!; + const node: TLGraphNode = (app.graph as TLGraph).getNodeById(link.origin_id)!; + if (!node) { + // Bummer, somthing happened.. should we cleanup? + app.graph.removeLink(linkId); + currentNode = null; + break; + } + const type = (node.constructor as typeof TLGraphNode).type; + if (type?.includes("Reroute")) { + if (node === this) { + // We've found a circle + currentNode.disconnectInput(link.target_slot); + currentNode = null; + } else { + // Move the previous node + currentNode = node; + } + } else { + // We've found the end + inputNode = node; + inputNodeOutputSlot = link.origin_slot; + input = node.outputs[inputNodeOutputSlot] ?? null; + inputType = input?.type ?? null; + break; + } + } else { + // This path has no input node + currentNode = null; + break; + } + } + + // Find all outputs + const nodes: TLGraphNode[] = [this]; + let outputNode = null; + let outputType = null; + // For primitive nodes, which look at the widget to dsplay themselves. + let outputWidgetConfig = null; + let outputWidget = null; + while (nodes.length) { + currentNode = nodes.pop()!; + const outputs = (currentNode.outputs ? 
currentNode.outputs[0]!.links : []) || []; + if (outputs.length) { + for (const linkId of outputs) { + const link = app.graph.links[linkId]; + + // When disconnecting sometimes the link is still registered + if (!link) continue; + + const node = app.graph.getNodeById(link.target_id) as TLGraphNode; + // Don't know why this ever happens.. but it did around the repeater.. + if (!node) continue; + const type = (node.constructor as any).type; + if (type?.includes("Reroute")) { + // Follow reroute nodes + nodes.push(node); + updateNodes.push(node); + } else { + // We've found an output + const output = node.inputs?.[link.target_slot] ?? null; + const nodeOutType = output?.type; + if (nodeOutType == null) { + console.warn( + `[rgthree] Reroute - Connected node ${node.id} does not have type information for ` + + `slot ${link.target_slot}. Skipping connection enforcement, but something is odd ` + + `with that node.`, + ); + } else if ( + inputType && + inputType !== "*" && + nodeOutType !== "*" && + !isValidConnection(input, output) + ) { + // The output doesnt match our input so disconnect it + console.warn( + `[rgthree] Reroute - Disconnecting connected node's input (${node.id}.${ + link.target_slot + }) (${node.type}) because its type (${String( + nodeOutType, + )}) does not match the reroute type (${String(inputType)})`, + ); + node.disconnectInput(link.target_slot); + } else { + outputType = nodeOutType; + outputNode = node; + outputWidgetConfig = null; + outputWidget = null; + // For primitive nodes, which look at the widget to dsplay themselves. + if (output?.widget) { + try { + const config = getWidgetConfig(output); + if (!outputWidgetConfig && config) { + outputWidgetConfig = config[1] ?? {}; + outputType = config[0]; + if (!outputWidget) { + outputWidget = outputNode.widgets?.find( + (w) => w.name === output?.widget?.name, + ); + } + const merged = mergeIfValid(output, [config[0], outputWidgetConfig]); + if (merged.customConfig) { + outputWidgetConfig = merged.customConfig; + } + } + } catch (e) { + // Something happened, probably because comfyUI changes their methods. + console.error( + "[rgthree] Could not propagate widget infor for reroute; maybe ComfyUI updated?", + ); + outputWidgetConfig = null; + outputWidget = null; + } + } + } + } + } + } else { + // No more outputs for this path + } + } + + const displayType = inputType || outputType || "*"; + const color = LGraphCanvas.link_type_colors[displayType]; + + // Update the types of each node + for (const node of updateNodes) { + // If we dont have an input type we are always wildcard but we'll show the output type + // This lets you change the output link to a different type and all nodes will update + node.outputs[0]!.type = inputType || "*"; + (node as any).__outputType = displayType; + node.outputs[0]!.name = input?.name || ""; + node.size = node.computeSize(); + (node as any).applyNodeSize?.(); + + for (const l of node.outputs[0]!.links || []) { + const link = app.graph.links[l]; + if (link) { + link.color = color; + } + } + + try { + // For primitive nodes, which look at the widget to dsplay themselves. + if (outputWidgetConfig && outputWidget && outputType) { + node.inputs[0]!.widget = { name: "value" }; + setWidgetConfig( + node.inputs[0], + [outputType ?? displayType, outputWidgetConfig], + outputWidget, + ); + } else { + setWidgetConfig(node.inputs[0], null); + } + } catch (e) { + // Something happened, probably because comfyUI changes their methods. 
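+        // Recovery below is best-effort: log the error, drop the widget config/widget we were
+        // trying to propagate, and remove any stale input widget so stabilization can still finish.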
+ console.error("[rgthree] Could not set widget config for reroute; maybe ComfyUI updated?"); + outputWidgetConfig = null; + outputWidget = null; + if (node.inputs[0]?.widget) { + delete node.inputs[0].widget; + } + } + } + + if (inputNode && inputNodeOutputSlot != null) { + const links = inputNode.outputs[inputNodeOutputSlot]!.links; + for (const l of links || []) { + const link = app.graph.links[l]; + if (link) { + link.color = color; + } + } + } + (inputNode as any)?.onConnectionsChainChange?.(); + (outputNode as any)?.onConnectionsChainChange?.(); + app.graph.setDirtyCanvas(true, true); + } + + /** + * When called, sets the node size, and the properties size, and calls out to `stabilizeLayout`. + */ + override setSize(size: Vector2): void { + const oldSize: Vector2 = [...this.size]; + const newSize: Vector2 = [...size]; + super.setSize(newSize); + this.properties["size"] = [...this.size]; + this.stabilizeLayout(oldSize, newSize); + } + + /** + * Looks at the current layout and determins if we also need to set a `connections_dir` based on + * the size of the node (and what that connections_dir should be). + */ + private stabilizeLayout(oldSize: Vector2, newSize: Vector2) { + if (newSize[0] === 10 || newSize[1] === 10) { + const props = this.properties; + props["connections_layout"] = props["connections_layout"] || ["Left", "Right"]; + const layout = props["connections_layout"]; + props["connections_dir"] = props["connections_dir"] || [-1, -1]; + const dir = props["connections_dir"]; + + if (oldSize[0] > 10 && newSize[0] === 10) { + dir[0] = LiteGraph.DOWN; + dir[1] = LiteGraph.UP; + if (layout[0] === "Bottom") { + layout[1] = "Top"; + } else if (layout[1] === "Top") { + layout[0] = "Bottom"; + } else { + layout[0] = "Top"; + layout[1] = "Bottom"; + dir[0] = LiteGraph.UP; + dir[1] = LiteGraph.DOWN; + } + this.setDirtyCanvas(true, true); + } else if (oldSize[1] > 10 && newSize[1] === 10) { + dir[0] = LiteGraph.RIGHT; + dir[1] = LiteGraph.LEFT; + if (layout[0] === "Right") { + layout[1] = "Left"; + } else if (layout[1] === "Left") { + layout[0] = "Right"; + } else { + layout[0] = "Left"; + layout[1] = "Right"; + dir[0] = LiteGraph.LEFT; + dir[1] = LiteGraph.RIGHT; + } + this.setDirtyCanvas(true, true); + } + } + SERVICE.handleMoveOrResizeNodeMaybeWhileDragging(this); + } + + applyNodeSize() { + this.properties["size"] = this.properties["size"] || RerouteNode.size; + this.properties["size"] = [ + Number(this.properties["size"][0]), + Number(this.properties["size"][1]), + ]; + this.size = this.properties["size"]; + app.graph.setDirtyCanvas(true, true); + } + + /** + * Rotates the node, including changing size and moving input's and output's layouts. 
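+   *
+   * The layout math walks `LAYOUT_CLOCKWISE` by index (+1 for 90°, -1 for -90°, +2 for 180°);
+   * the `(((i + n) % 4) + 4) % 4` form keeps the index within 0..3 even when it goes negative.
+   * The node's width and height are then swapped via `setSize([h, w])`.
+   *
+   * Illustrative usage (a sketch, not from the original docs):
+   *
+   * @example
+   * const reroute = LiteGraph.createNode("Reroute (rgthree)") as RerouteNode;
+   * reroute.rotate(90);   // quarter turn
+   * reroute.rotate(180);  // then a half turn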
+ */ + rotate(degrees: 90 | -90 | 180) { + const w = this.size[0]; + const h = this.size[1]; + this.properties["connections_layout"] = + this.properties["connections_layout"] || (this as RerouteNode).defaultConnectionsLayout; + const inputDirIndex = LAYOUT_CLOCKWISE.indexOf(this.properties["connections_layout"][0]); + const outputDirIndex = LAYOUT_CLOCKWISE.indexOf(this.properties["connections_layout"][1]); + if (degrees == 90 || degrees === -90) { + if (degrees === -90) { + this.properties["connections_layout"][0] = + LAYOUT_CLOCKWISE[(((inputDirIndex - 1) % 4) + 4) % 4]; + this.properties["connections_layout"][1] = + LAYOUT_CLOCKWISE[(((outputDirIndex - 1) % 4) + 4) % 4]; + } else { + this.properties["connections_layout"][0] = + LAYOUT_CLOCKWISE[(((inputDirIndex + 1) % 4) + 4) % 4]; + this.properties["connections_layout"][1] = + LAYOUT_CLOCKWISE[(((outputDirIndex + 1) % 4) + 4) % 4]; + } + } else if (degrees === 180) { + this.properties["connections_layout"][0] = + LAYOUT_CLOCKWISE[(((inputDirIndex + 2) % 4) + 4) % 4]; + this.properties["connections_layout"][1] = + LAYOUT_CLOCKWISE[(((outputDirIndex + 2) % 4) + 4) % 4]; + } + this.setSize([h, w]); + } + + /** + * Manually handles a move called from `onMouseMove` while the resize shortcut is active. + */ + private manuallyHandleMove(event: PointerEvent) { + const shortcut = this.shortcuts.move; + if (shortcut.state) { + const diffX = Math.round((event.clientX - shortcut.initialMousePos[0]) / 10) * 10; + const diffY = Math.round((event.clientY - shortcut.initialMousePos[1]) / 10) * 10; + this.pos[0] = shortcut.initialNodePos[0] + diffX; + this.pos[1] = shortcut.initialNodePos[1] + diffY; + this.setDirtyCanvas(true, true); + SERVICE.handleMoveOrResizeNodeMaybeWhileDragging(this); + } + } + + /** + * Manually handles a resize called from `onMouseMove` while the resize shortcut is active. + */ + private manuallyHandleResize(event: PointerEvent) { + const shortcut = this.shortcuts.resize; + if (shortcut.state) { + let diffX = Math.round((event.clientX - shortcut.initialMousePos[0]) / 10) * 10; + let diffY = Math.round((event.clientY - shortcut.initialMousePos[1]) / 10) * 10; + diffX *= shortcut.resizeOnSide[0] === LiteGraph.LEFT ? -1 : 1; + diffY *= shortcut.resizeOnSide[1] === LiteGraph.UP ? -1 : 1; + const oldSize: Vector2 = [...this.size]; + this.setSize([ + Math.max(10, shortcut.initialNodeSize[0] + diffX), + Math.max(10, shortcut.initialNodeSize[1] + diffY), + ]); + if (shortcut.resizeOnSide[0] === LiteGraph.LEFT && oldSize[0] > 10) { + this.pos[0] = shortcut.initialNodePos[0] - diffX; + } + if (shortcut.resizeOnSide[1] === LiteGraph.UP && oldSize[1] > 10) { + this.pos[1] = shortcut.initialNodePos[1] - diffY; + } + this.setDirtyCanvas(true, true); + } + } + + /** + * Cycles the connection (input or output) to the next available layout. Note, when the width or + * height is only 10px, then layout sticks to the ends of the longer size, and we move a + * `connections_dir` property which is only paid attention to in `utils` when size of one axis + * is equal to 10. + * `manuallyHandleResize` handles the reset of `connections_dir` when a node is resized. + */ + private cycleConnection(ioDir: IoDirection) { + const props = this.properties; + props["connections_layout"] = props["connections_layout"] || ["Left", "Right"]; + const propIdx = ioDir == IoDirection.INPUT ? 0 : 1; + const oppositeIdx = propIdx ? 
0 : 1; + let currentLayout = props["connections_layout"][propIdx]; + let oppositeLayout = props["connections_layout"][oppositeIdx]; + + if (this.size[0] === 10 || this.size[1] === 10) { + props["connections_dir"] = props["connections_dir"] || [-1, -1]; + let currentDir = props["connections_dir"][propIdx] as number; + // let oppositeDir = props["connections_dir"][oppositeIdx]; + const options: number[] = + this.size[0] === 10 + ? currentLayout === "Bottom" + ? [LiteGraph.DOWN, LiteGraph.RIGHT, LiteGraph.LEFT] + : [LiteGraph.UP, LiteGraph.LEFT, LiteGraph.RIGHT] + : currentLayout === "Right" + ? [LiteGraph.RIGHT, LiteGraph.DOWN, LiteGraph.UP] + : [LiteGraph.LEFT, LiteGraph.UP, LiteGraph.DOWN]; + let idx = options.indexOf(currentDir); + let next = options[idx + 1] ?? options[0]!; + this.properties["connections_dir"][propIdx] = next; + return; + } + + let next = currentLayout; + do { + let idx = LAYOUT_CLOCKWISE.indexOf(next); + next = LAYOUT_CLOCKWISE[idx + 1] ?? LAYOUT_CLOCKWISE[0]!; + } while (next === oppositeLayout); + this.properties["connections_layout"][propIdx] = next; + this.setDirtyCanvas(true, true); + } + + /** + * Handles a mouse move while this node is selected. Note, though, that the actual work here is + * processed bycause the move and resize shortcuts set `canvas.node_capturing_input` to this node + * when they start (otherwise onMouseMove only fires when the mouse moves within the node's + * bounds). + */ + override onMouseMove(event: PointerEvent): void { + if (this.shortcuts.move.state) { + const shortcut = this.shortcuts.move; + if (shortcut.initialMousePos[0] === -1) { + shortcut.initialMousePos[0] = event.clientX; + shortcut.initialMousePos[1] = event.clientY; + shortcut.initialNodePos[0] = this.pos[0]; + shortcut.initialNodePos[1] = this.pos[1]; + } + this.manuallyHandleMove(event); + } else if (this.shortcuts.resize.state) { + const shortcut = this.shortcuts.resize; + if (shortcut.initialMousePos[0] === -1) { + shortcut.initialMousePos[0] = event.clientX; + shortcut.initialMousePos[1] = event.clientY; + shortcut.initialNodeSize[0] = this.size[0]; + shortcut.initialNodeSize[1] = this.size[1]; + shortcut.initialNodePos[0] = this.pos[0]; + shortcut.initialNodePos[1] = this.pos[1]; + const canvas = app.canvas as TLGraphCanvas; + const offset = canvas.convertEventToCanvasOffset(event); + shortcut.resizeOnSide[0] = this.pos[0] > offset[0] ? LiteGraph.LEFT : LiteGraph.RIGHT; + shortcut.resizeOnSide[1] = this.pos[1] > offset[1] ? LiteGraph.UP : LiteGraph.DOWN; + } + this.manuallyHandleResize(event); + } + } + + /** + * Handles a key down while this node is selected, starting a shortcut if the keys are newly + * pressed. + */ + override onKeyDown(event: KeyboardEvent) { + super.onKeyDown(event); + const canvas = app.canvas as TLGraphCanvas; + + // Only handle shortcuts while we're enabled in the config. + if (CONFIG_FAST_REROUTE_ENABLED) { + for (const [key, shortcut] of Object.entries(this.shortcuts)) { + if (!shortcut.state) { + const keys = KEY_EVENT_SERVICE.areOnlyKeysDown(shortcut.keys); + if (keys) { + shortcut.state = true; + if (key === "rotate") { + this.rotate(90); + } else if (key.includes("connection")) { + this.cycleConnection(key.includes("input") ? IoDirection.INPUT : IoDirection.OUTPUT); + } + if ((shortcut as any).initialMousePos) { + canvas.node_capturing_input = this; + } + } + } + } + } + } + + /** + * Handles a key up while this node is selected, canceling any current shortcut. 
+   */
+  override onKeyUp(event: KeyboardEvent) {
+    super.onKeyUp(event);
+    const canvas = app.canvas as TLGraphCanvas;
+
+    // Only handle shortcuts if fast-reroute shortcuts are enabled in the config.
+    if (CONFIG_FAST_REROUTE_ENABLED) {
+      for (const [key, shortcut] of Object.entries(this.shortcuts)) {
+        if (shortcut.state) {
+          const keys = KEY_EVENT_SERVICE.areOnlyKeysDown(shortcut.keys);
+          if (!keys) {
+            shortcut.state = false;
+            if ((shortcut as any).initialMousePos) {
+              (shortcut as any).initialMousePos = [-1, -1];
+              // Only release input capture if we are the node currently capturing it.
+              if (canvas.node_capturing_input === this) {
+                canvas.node_capturing_input = null;
+              }
+              this.setDirtyCanvas(true, true);
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Handles a deselection of the node, canceling any current shortcut.
+   */
+  override onDeselected(): void {
+    super.onDeselected?.();
+    const canvas = app.canvas as TLGraphCanvas;
+    for (const [key, shortcut] of Object.entries(this.shortcuts)) {
+      shortcut.state = false;
+      if ((shortcut as any).initialMousePos) {
+        (shortcut as any).initialMousePos = [-1, -1];
+        // Only release input capture if we are the node currently capturing it.
+        if (canvas.node_capturing_input === this) {
+          canvas.node_capturing_input = null;
+        }
+        this.setDirtyCanvas(true, true);
+      }
+    }
+  }
+
+  override onRemoved(): void {
+    super.onRemoved?.();
+    // If we're removed, let the link-dragging service above know. This is wrapped in a setTimeout
+    // because onRemoved is called while LiteGraph is still doing its own removal cleanup, and we
+    // want the handler to run after that cleanup finishes.
+    setTimeout(() => {
+      SERVICE.handleRemovedNodeMaybeWhileDragging(this);
+    }, 32);
+  }
+
+  override getHelp() {
+    return `

+      <p>
+        Finally, a comfortable, powerful reroute node with true multi-direction and powerful
+        shortcuts to bring your workflow to the next level.
+      </p>
+      ${
+        !CONFIG_FAST_REROUTE_ENABLED
+          ? `<p><i>Fast Shortcuts are currently disabled.</i></p>`
+          : `
+      <ul>
+        <li><p>
+          <code>${CONFIG_KEY_CREATE_WHILE_LINKING}</code> Create a new reroute node while dragging
+          a link, connecting it to the link in place and continuing the link.
+        </p></li>
+        <li><p>
+          <code>${CONFIG_KEY_ROTATE}</code> Rotate the selected reroute node counter-clockwise 90
+          degrees.
+        </p></li>
+        <li><p>
+          <code>${CONFIG_KEY_RESIZE}</code> Resize the selected reroute node from the nearest
+          corner by holding down and moving your mouse.
+        </p></li>
+        <li><p>
+          <code>${CONFIG_KEY_MOVE}</code> Move the selected reroute node by holding down and
+          moving your mouse.
+        </p></li>
+        <li><p>
+          <code>${CONFIG_KEY_CXN_INPUT}</code> Change the input layout/direction of the selected
+          reroute node.
+        </p></li>
+        <li><p>
+          <code>${CONFIG_KEY_CXN_OUTPUT}</code> Change the output layout/direction of the selected
+          reroute node.
+        </p></li>
+      </ul>
+      `
+      }
+      <p>
+        To change, ${!CONFIG_FAST_REROUTE_ENABLED ? "enable" : "disable"} or configure shortcuts,
+        make a copy of
+        <code>/custom_nodes/rgthree-comfy/rgthree_config.json.default</code> to
+        <code>/custom_nodes/rgthree-comfy/rgthree_config.json</code> and configure under
+        <code>nodes > reroute > fast_reroute</code>.
+      </p>
+ `; + } +} + +addMenuItem(RerouteNode, app, { + name: (node) => `${node.properties?.["showLabel"] ? "Hide" : "Show"} Label/Title`, + property: "showLabel", + callback: async (node, value) => { + app.graph.setDirtyCanvas(true, true); + }, +}); + +addMenuItem(RerouteNode, app, { + name: (node) => `${node.resizable ? "No" : "Allow"} Resizing`, + callback: (node) => { + (node as RerouteNode).setResizable(!node.resizable); + node.size[0] = Math.max(40, node.size[0]); + node.size[1] = Math.max(30, node.size[1]); + (node as RerouteNode).applyNodeSize(); + }, +}); + +addMenuItem(RerouteNode, app, { + name: "Static Width", + property: "size", + subMenuOptions: (() => { + const options = []; + for (let w = 8; w > 0; w--) { + options.push(`${w * 10}`); + } + return options; + })(), + prepareValue: (value, node) => [Number(value), node.size[1]], + callback: (node) => { + (node as RerouteNode).setResizable(false); + (node as RerouteNode).applyNodeSize(); + }, +}); + +addMenuItem(RerouteNode, app, { + name: "Static Height", + property: "size", + subMenuOptions: (() => { + const options = []; + for (let w = 8; w > 0; w--) { + options.push(`${w * 10}`); + } + return options; + })(), + prepareValue: (value, node) => [node.size[0], Number(value)], + callback: (node) => { + (node as RerouteNode).setResizable(false); + (node as RerouteNode).applyNodeSize(); + }, +}); + +addConnectionLayoutSupport( + RerouteNode, + app, + [ + ["Left", "Right"], + ["Left", "Top"], + ["Left", "Bottom"], + ["Right", "Left"], + ["Right", "Top"], + ["Right", "Bottom"], + ["Top", "Left"], + ["Top", "Right"], + ["Top", "Bottom"], + ["Bottom", "Left"], + ["Bottom", "Right"], + ["Bottom", "Top"], + ], + (node) => { + (node as RerouteNode).applyNodeSize(); + }, +); + +addMenuItem(RerouteNode, app, { + name: "Rotate", + subMenuOptions: [ + "Rotate 90° Clockwise", + "Rotate 90° Counter-Clockwise", + "Rotate 180°", + null, + "Flip Horizontally", + "Flip Vertically", + ], + callback: (node_: TLGraphNode, value) => { + const node = node_ as RerouteNode; + if (value?.startsWith("Rotate 90° Clockwise")) { + node.rotate(90); + } else if (value?.startsWith("Rotate 90° Counter-Clockwise")) { + node.rotate(-90); + } else if (value?.startsWith("Rotate 180°")) { + node.rotate(180); + } else { + const inputDirIndex = LAYOUT_CLOCKWISE.indexOf(node.properties["connections_layout"][0]); + const outputDirIndex = LAYOUT_CLOCKWISE.indexOf(node.properties["connections_layout"][1]); + if (value?.startsWith("Flip Horizontally")) { + if (["Left", "Right"].includes(node.properties["connections_layout"][0])) { + node.properties["connections_layout"][0] = + LAYOUT_CLOCKWISE[(((inputDirIndex + 2) % 4) + 4) % 4]; + } + if (["Left", "Right"].includes(node.properties["connections_layout"][1])) { + node.properties["connections_layout"][1] = + LAYOUT_CLOCKWISE[(((outputDirIndex + 2) % 4) + 4) % 4]; + } + } else if (value?.startsWith("Flip Vertically")) { + if (["Top", "Bottom"].includes(node.properties["connections_layout"][0])) { + node.properties["connections_layout"][0] = + LAYOUT_CLOCKWISE[(((inputDirIndex + 2) % 4) + 4) % 4]; + } + if (["Top", "Bottom"].includes(node.properties["connections_layout"][1])) { + node.properties["connections_layout"][1] = + LAYOUT_CLOCKWISE[(((outputDirIndex + 2) % 4) + 4) % 4]; + } + } + } + }, +}); + +addMenuItem(RerouteNode, app, { + name: "Clone New Reroute...", + subMenuOptions: ["Before", "After"], + callback: async (node, value) => { + const clone = node.clone(); + const pos = [...node.pos]; + if (value === "Before") { + 
clone.pos = [pos[0]! - 20, pos[1]! - 20]; + app.graph.add(clone); + await wait(); + const inputLinks = getSlotLinks(node.inputs[0]); + for (const inputLink of inputLinks) { + const link = inputLink.link; + const linkedNode = app.graph.getNodeById(link.origin_id) as TLGraphNode; + if (linkedNode) { + linkedNode.connect(0, clone, 0); + } + } + clone.connect(0, node, 0); + } else { + clone.pos = [pos[0]! + 20, pos[1]! + 20]; + app.graph.add(clone); + await wait(); + const outputLinks = getSlotLinks(node.outputs[0]); + node.connect(0, clone, 0); + for (const outputLink of outputLinks) { + const link = outputLink.link; + const linkedNode = app.graph.getNodeById(link.target_id) as TLGraphNode; + if (linkedNode) { + clone.connect(0, linkedNode, link.target_slot); + } + } + } + }, +}); + +app.registerExtension({ + name: "rgthree.Reroute", + registerCustomNodes() { + RerouteNode.setUp(); + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/rgthree.scss b/rgthree-comfy/src_web/comfyui/rgthree.scss new file mode 100644 index 0000000000000000000000000000000000000000..107e924192ba9dabf17d71a2058cd5dcad476e58 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/rgthree.scss @@ -0,0 +1,295 @@ +.rgthree-top-messages-container { + position: fixed; + z-index: 9999; + top: 0; + left: 0; + width: 100%; + height: 0; + display: flex; + flex-direction: column; + align-items: center; + justify-content: start; +} + +.rgthree-top-messages-container > div { + position: relative; + height: fit-content; + padding: 4px; + margin-top: -100px; /* re-set by JS */ + opacity: 0; + transition: all 0.33s ease-in-out; + z-index: 3; +} +.rgthree-top-messages-container > div:last-child { + z-index: 2; +} +.rgthree-top-messages-container > div:not(.-show) { + z-index: 1; +} + +.rgthree-top-messages-container > div.-show { + opacity: 1; + margin-top: 0px !important; +} + +.rgthree-top-messages-container > div.-show { + opacity: 1; + transform: translateY(0%); +} + +.rgthree-top-messages-container > div > div { + position: relative; + background: #353535; + color: #fff; + display: flex; + flex-direction: row; + align-items: center; + justify-content: center; + height: fit-content; + box-shadow: 0 0 10px rgba(0, 0, 0, 0.88); + padding: 6px 12px; + border-radius: 4px; + font-family: Arial, sans-serif; + font-size: 14px; +} +.rgthree-top-messages-container > div > div > span { + display: flex; + flex-direction: row; + align-items: center; + justify-content: center; +} +.rgthree-top-messages-container > div > div > span svg { + width: 20px; + height: auto; + margin-right: 8px; +} +.rgthree-top-messages-container > div > div > span svg.icon-checkmark { + fill: #2e9720; +} + +.rgthree-top-messages-container [type="warn"]::before, +.rgthree-top-messages-container [type="success"]::before { + content: '⚠️'; + display: inline-block; + flex: 0 0 auto; + font-size: 18px; + margin-right: 4px; + line-height: 1; +} +.rgthree-top-messages-container [type="success"]::before { + content: '🎉'; +} + +.rgthree-top-messages-container a { + cursor: pointer; + text-decoration: underline; + color: #fc0; + margin-left: 4px; + display: inline-block; + line-height: 1; +} + +.rgthree-top-messages-container a:hover { + color: #fc0; + text-decoration: none; +} + +/* Fix node selector being crazy long b/c of array types. */ +.litegraph.litesearchbox input, +.litegraph.litesearchbox select { + max-width: 250px; +} + +/* There's no reason for this z-index to be so high. 
It layers on top of things it shouldn't, + (like pythongssss' image gallery, the properties panel, etc.) */ +.comfy-multiline-input { + z-index: 1 !important; +} +.comfy-multiline-input:focus { + z-index: 2 !important; +} +.litegraph .dialog { + z-index: 3 !important; /* This is set to 1, but goes under the multi-line inputs, so bump it. */ +} + + +@import '../common/css/buttons.scss'; +@import '../common/css/dialog.scss'; +@import '../common/css/menu.scss'; + +.rgthree-dialog.-settings { + width: 100%; +} +.rgthree-dialog.-settings fieldset { + border: 1px solid rgba(255, 255, 255, 0.25); + padding: 0 12px 8px; + margin-bottom: 16px; +} +.rgthree-dialog.-settings fieldset > legend { + margin-left: 8px; + padding: 0 8px; + opacity: 0.5; +} +.rgthree-dialog.-settings .formrow { + display: flex; + flex-direction: column; +} +.rgthree-dialog.-settings .formrow + .formrow { + border-top: 1px solid rgba(255, 255, 255, 0.25); +} +.rgthree-dialog.-settings .fieldrow { + display: flex; + flex-direction: row; +} +.rgthree-dialog.-settings .fieldrow > label { + flex: 1 1 auto; + user-select: none; + padding: 8px 12px 12px; +} +.rgthree-dialog.-settings .fieldrow > label span { + font-weight: bold; +} +.rgthree-dialog.-settings .fieldrow > label small { + display: block; + margin-top: 4px; + font-size: calc(11rem / 16); + opacity: 0.75; + padding-left: 16px; +} +.rgthree-dialog.-settings .fieldrow ~ .fieldrow { + font-size: 0.9rem; + border-top: 1px dotted rgba(255, 255, 255, 0.25); +} +.rgthree-dialog.-settings .fieldrow ~ .fieldrow label { + padding-left: 28px; +} +.rgthree-dialog.-settings .fieldrow:first-child:not(.-checked) ~ .fieldrow { + display: none; +} +.rgthree-dialog.-settings .fieldrow:hover { + background: rgba(255,255,255,0.1); +} +.rgthree-dialog.-settings .fieldrow ~ .fieldrow span { + font-weight: normal; +} + +.rgthree-dialog.-settings .fieldrow > .fieldrow-value { + display: flex; + align-items: center; + justify-content: end; + flex: 0 0 auto; + width: 50%; + max-width: 230px; +} +.rgthree-dialog.-settings .fieldrow.-type-boolean > .fieldrow-value { + max-width: 64px; +} +.rgthree-dialog.-settings .fieldrow.-type-number input { + width: 48px; + text-align: right; +} + +.rgthree-dialog.-settings .fieldrow input[type="checkbox"] { + width: 24px; + height: 24px; + cursor: pointer; +} + +.rgthree-comfyui-settings-row div { + display: flex; + flex-direction: row; + align-items: center; + justify-content: end; +} +.rgthree-comfyui-settings-row div svg { + width: 36px; + height: 36px; + margin-right: 16px; +} + + +.litegraph.litecontextmenu .litemenu-title .rgthree-contextmenu-title-rgthree-comfy, +.litegraph.litecontextmenu .litemenu-entry.rgthree-contextmenu-item { + display: flex; + flex-direction: row; + align-items: center; + justify-content: start; +} + +.litegraph.litecontextmenu .litemenu-title .rgthree-contextmenu-title-rgthree-comfy svg, +.litegraph.litecontextmenu .litemenu-entry.rgthree-contextmenu-item svg { + fill: currentColor; + width: auto; + height: 16px; + margin-right: 6px; +} +.litegraph.litecontextmenu .litemenu-entry.rgthree-contextmenu-item svg.github-star { + fill: rgb(227, 179, 65); +} + +.litegraph.litecontextmenu .litemenu-title .rgthree-contextmenu-title-rgthree-comfy, +.litegraph.litecontextmenu .litemenu-entry.rgthree-contextmenu-label { + color: #dde; + background-color: #212121 !important; + margin: 0; + padding: 2px; + cursor: default; + opacity: 1; + padding: 4px; + font-weight: bold; +} +.litegraph.litecontextmenu .litemenu-title 
.rgthree-contextmenu-title-rgthree-comfy { + font-size: 1.1em; + color: #fff; + background-color: #090909 !important; + justify-content: center; + padding: 4px 8px; +} + +rgthree-progress-bar { + display: block; + position: relative; + z-index: 999; + top: 0; + left: 0; + height: 14px; + font-size: 10px; + width: 100%; + overflow: hidden; + box-shadow: 0px 0px 3px rgba(0, 0, 0, 0.25); + box-shadow: + inset 0px -1px 0px rgba(0, 0, 0, 0.25), + 0px 1px 0px rgba(255, 255, 255, 0.125); + +} + +* ~ rgthree-progress-bar, +.comfyui-body-bottom rgthree-progress-bar { + box-shadow: + 0px -1px 0px rgba(0, 0, 0, 1), + inset 0px 1px 0px rgba(255, 255, 255, 0.15), inset 0px -1px 0px rgba(0, 0, 0, 0.25), 0px 1px 0px rgba(255, 255, 255, 0.125); +} + +body:not([style*=grid]){ + rgthree-progress-bar { + position: fixed; + top: 0px; + bottom: auto; + } + rgthree-progress-bar.rgthree-pos-bottom { + top: auto; + bottom: 0px; + } +} + + +.rgthree-debug-keydowns { + display: block; + position: fixed; + z-index: 999; + top: 3px; + right: 8px; + font-size: 10px; + color: #fff; + font-family: sans-serif; +} \ No newline at end of file diff --git a/rgthree-comfy/src_web/comfyui/rgthree.ts b/rgthree-comfy/src_web/comfyui/rgthree.ts new file mode 100644 index 0000000000000000000000000000000000000000..165491b526fa6a48fa08eb44d3703ebfde6d4b0a --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/rgthree.ts @@ -0,0 +1,996 @@ +import type { + LGraphCanvas as TLGraphCanvas, + LGraphNode, + SerializedLGraphNode, + serializedLGraph, + ContextMenuItem, + LGraph as TLGraph, + AdjustedMouseEvent, + IContextMenuOptions, +} from "typings/litegraph.js"; +import type { ComfyApiFormat, ComfyApiPrompt, ComfyApp } from "typings/comfy.js"; +import { app } from "scripts/app.js"; +import { api } from "scripts/api.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +import { fixBadLinks } from "rgthree/common/link_fixer.js"; +import { injectCss, wait } from "rgthree/common/shared_utils.js"; +import { replaceNode, waitForCanvas, waitForGraph } from "./utils.js"; +import { NodeTypesString, addRgthree, getNodeTypeStrings, stripRgthree } from "./constants.js"; +import { RgthreeProgressBar } from "rgthree/common/progress_bar.js"; +import { RgthreeConfigDialog } from "./config.js"; +import { + iconGear, + iconNode, + iconReplace, + iconStarFilled, + logoRgthree, +} from "rgthree/common/media/svgs.js"; +import type { Bookmark } from "./bookmark.js"; +import { createElement, query, queryOne } from "rgthree/common/utils_dom.js"; + +export enum LogLevel { + IMPORTANT = 1, + ERROR, + WARN, + INFO, + DEBUG, + DEV, +} + +const LogLevelKeyToLogLevel: { [key: string]: LogLevel } = { + IMPORTANT: LogLevel.IMPORTANT, + ERROR: LogLevel.ERROR, + WARN: LogLevel.WARN, + INFO: LogLevel.INFO, + DEBUG: LogLevel.DEBUG, + DEV: LogLevel.DEV, +}; + +type ConsoleLogFns = "log" | "error" | "warn" | "debug" | "info"; +const LogLevelToMethod: { [key in LogLevel]: ConsoleLogFns } = { + [LogLevel.IMPORTANT]: "log", + [LogLevel.ERROR]: "error", + [LogLevel.WARN]: "warn", + [LogLevel.INFO]: "info", + [LogLevel.DEBUG]: "log", + [LogLevel.DEV]: "log", +}; +const LogLevelToCSS: { [key in LogLevel]: string } = { + [LogLevel.IMPORTANT]: "font-weight: bold; color: blue;", + [LogLevel.ERROR]: "", + [LogLevel.WARN]: "", + [LogLevel.INFO]: "font-style: italic; color: blue;", + [LogLevel.DEBUG]: "font-style: italic; color: #444;", + [LogLevel.DEV]: "color: #004b68;", +}; + +let GLOBAL_LOG_LEVEL = LogLevel.ERROR; + +/** + * A blocklist of extensions to disallow 
hooking into rgthree's base classes when calling the + * `rgthree.invokeExtensionsAsync` method (which runs outside of ComfyNode's + * `app.invokeExtensionsAsync` which is private). + * + * In Apr 2024 the base rgthree node class added support for other extensions using `nodeCreated` + * and `beforeRegisterNodeDef` which allows other extensions to modify the class. However, since it + * had been months since divorcing the ComfyNode in rgthree-comfy due to instability and + * inflexibility, this was a bit risky as other extensions hadn't ever run with this ability. This + * list attempts to block extensions from being able to call into rgthree-comfy nodes via the + * `nodeCreated` and `beforeRegisterNodeDef` callbacks now that rgthree-comfy is utilizing them + * because they do not work. Oddly, it's ComfyUI's own extension that is broken. + */ +const INVOKE_EXTENSIONS_BLOCKLIST = [ + { + name: "Comfy.WidgetInputs", + reason: + "Major conflict with rgthree-comfy nodes' inputs causing instability and " + + "repeated link disconnections.", + }, + { + name: "efficiency.widgethider", + reason: + "Overrides value getter before widget getter is prepared. Can be lifted if/when " + + "https://github.com/jags111/efficiency-nodes-comfyui/pull/203 is pulled.", + }, +]; + +/** A basic wrapper around logger. */ +class Logger { + /** Logs a message to the console if it meets the current log level. */ + log(level: LogLevel, message: string, ...args: any[]) { + const [n, v] = this.logParts(level, message, ...args); + console[n]?.(...v); + } + + /** + * Returns a tuple of the console function and its arguments. Useful for callers to make the + * actual console. call to gain benefits of DevTools knowing the source line. + * + * If the input is invalid or the level doesn't meet the configuration level, then the return + * value is an unknown function and empty set of values. Callers can use optionla chaining + * successfully: + * + * const [fn, values] = logger.logPars(LogLevel.INFO, 'my message'); + * console[fn]?.(...values); // Will work even if INFO won't be logged. + * + */ + logParts(level: LogLevel, message: string, ...args: any[]): [ConsoleLogFns, any[]] { + if (level <= GLOBAL_LOG_LEVEL) { + const css = LogLevelToCSS[level] || ""; + if (level === LogLevel.DEV) { + message = `🔧 ${message}`; + } + return [LogLevelToMethod[level], [`%c${message}`, css, ...args]]; + } + return ["none" as "info", []]; + } +} + +/** + * A log session, with the name as the prefix. A new session will stack prefixes. + */ +class LogSession { + readonly logger = new Logger(); + readonly logsCache: { [key: string]: { lastShownTime: number } } = {}; + + constructor(readonly name?: string) {} + + /** + * Returns the console log method to use and the arguments to pass so the call site can log from + * there. This extra work at the call site allows for easier debugging in the dev console. + * + * const [logMethod, logArgs] = logger.logParts(LogLevel.DEBUG, message, ...args); + * console[logMethod]?.(...logArgs); + */ + logParts(level: LogLevel, message?: string, ...args: any[]): [ConsoleLogFns, any[]] { + message = `${this.name || ""}${message ? " " + message : ""}`; + return this.logger.logParts(level, message, ...args); + } + + logPartsOnceForTime( + level: LogLevel, + time: number, + message?: string, + ...args: any[] + ): [ConsoleLogFns, any[]] { + message = `${this.name || ""}${message ? 
" " + message : ""}`; + const cacheKey = `${level}:${message}`; + const cacheEntry = this.logsCache[cacheKey]; + const now = +new Date(); + if (cacheEntry && cacheEntry.lastShownTime + time > now) { + return ["none" as "info", []]; + } + const parts = this.logger.logParts(level, message, ...args); + if (console[parts[0]]) { + this.logsCache[cacheKey] = this.logsCache[cacheKey] || ({} as { lastShownTime: number }); + this.logsCache[cacheKey]!.lastShownTime = now; + } + return parts; + } + + debugParts(message?: string, ...args: any[]) { + return this.logParts(LogLevel.DEBUG, message, ...args); + } + + infoParts(message?: string, ...args: any[]) { + return this.logParts(LogLevel.INFO, message, ...args); + } + + warnParts(message?: string, ...args: any[]) { + return this.logParts(LogLevel.WARN, message, ...args); + } + + newSession(name?: string) { + return new LogSession(`${this.name}${name}`); + } +} + +export type RgthreeUiMessage = { + id: string; + message: string; + type?: "warn" | "info" | "success" | null; + timeout?: number; + // closeable?: boolean; // TODO + actions?: Array<{ + label: string; + href?: string; + callback?: (event: MouseEvent) => void; + }>; +}; + +/** + * A global class as 'rgthree'; exposed on wiindow. Lots can go in here. + */ +class Rgthree extends EventTarget { + /** Exposes the ComfyUI api instance on rgthree. */ + readonly api = api; + private settingsDialog: RgthreeConfigDialog | null = null; + private progressBarEl: RgthreeProgressBar | null = null; + private rgthreeCssPromise: Promise; + + /** Stores a node id that we will use to queu only that output node (with `queueOutputNode`). */ + private queueNodeIds: number[] | null = null; + + logger = new LogSession("[rgthree]"); + + monitorBadLinksAlerted = false; + monitorLinkTimeout: number | null = null; + + processingQueue = false; + loadingApiJson = false; + replacingReroute: number | null = null; + processingMouseDown = false; + processingMouseUp = false; + processingMouseMove = false; + lastAdjustedMouseEvent: AdjustedMouseEvent | null = null; + + // Comfy/LiteGraph states so nodes and tell what the hell is going on. + canvasCurrentlyCopyingToClipboard = false; + canvasCurrentlyCopyingToClipboardWithMultipleNodes = false; + initialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff: any = null; + + private elDebugKeydowns: HTMLDivElement | null = null; + + private readonly isMac: boolean = !!( + navigator.platform?.toLocaleUpperCase().startsWith("MAC") || + (navigator as any).userAgentData?.platform?.toLocaleUpperCase().startsWith("MAC") + ); + + constructor() { + super(); + + const logLevel = + LogLevelKeyToLogLevel[CONFIG_SERVICE.getConfigValue("log_level")] ?? GLOBAL_LOG_LEVEL; + this.setLogLevel(logLevel); + + this.initializeGraphAndCanvasHooks(); + this.initializeComfyUIHooks(); + this.initializeContextMenu(); + + this.rgthreeCssPromise = injectCss("extensions/rgthree-comfy/rgthree.css"); + + this.initializeProgressBar(); + + CONFIG_SERVICE.addEventListener("config-change", ((e: CustomEvent) => { + if (e.detail?.key?.includes("features.progress_bar")) { + this.initializeProgressBar(); + } + }) as EventListener); + } + + /** + * Initializes the top progress bar, if it's configured. + */ + async initializeProgressBar() { + if (CONFIG_SERVICE.getConfigValue("features.progress_bar.enabled")) { + await this.rgthreeCssPromise; + if (!this.progressBarEl) { + this.progressBarEl = RgthreeProgressBar.create(); + this.progressBarEl.setAttribute( + "title", + "Progress Bar by rgthree. 
right-click for rgthree menu.", + ); + + this.progressBarEl.addEventListener("contextmenu", async (e) => { + e.stopPropagation(); + e.preventDefault(); + }); + + this.progressBarEl.addEventListener("pointerdown", async (e) => { + LiteGraph.closeAllContextMenus(); + if (e.button == 2) { + const canvas = await waitForCanvas(); + new LiteGraph.ContextMenu( + this.getRgthreeContextMenuItems(), + { + title: `
${logoRgthree} rgthree-comfy
`, + left: e.clientX, + top: 5, + }, + canvas.getCanvasWindow(), + ); + return; + } + if (e.button == 0) { + const nodeId = this.progressBarEl?.currentNodeId; + if (nodeId) { + const [canvas, graph] = await Promise.all([waitForCanvas(), waitForGraph()]); + const node = graph.getNodeById(Number(nodeId)); + if (node) { + canvas.centerOnNode(node); + e.stopPropagation(); + e.preventDefault(); + } + } + return; + } + }); + } + // Handle both cases in case someone hasn't updated. Can probably just assume + // `isUpdatedComfyBodyClasses` is true in the near future. + const isUpdatedComfyBodyClasses = !!queryOne(".comfyui-body-top"); + const position = CONFIG_SERVICE.getConfigValue("features.progress_bar.position"); + this.progressBarEl.classList.toggle("rgthree-pos-bottom", position === "bottom"); + // If ComfyUI is updated with the body segments, then use that. + if (isUpdatedComfyBodyClasses) { + if (position === "bottom") { + queryOne(".comfyui-body-bottom")!.appendChild(this.progressBarEl); + } else { + queryOne(".comfyui-body-top")!.appendChild(this.progressBarEl); + } + } else { + document.body.appendChild(this.progressBarEl); + } + const height = CONFIG_SERVICE.getConfigValue("features.progress_bar.height") || 14; + this.progressBarEl.style.height = `${height}px`; + const fontSize = Math.max(10, Number(height) - 10); + this.progressBarEl.style.fontSize = `${fontSize}px`; + this.progressBarEl.style.fontWeight = fontSize <= 12 ? "bold" : "normal"; + } else { + this.progressBarEl?.remove(); + } + } + + /** + * Initialize a bunch of hooks into LiteGraph itself so we can either keep state or context on + * what's happening so nodes can respond appropriately. This is usually to fix broken assumptions + * in the unowned code [🤮], but sometimes to add features or enhancements too [⭐]. + */ + private async initializeGraphAndCanvasHooks() { + const rgthree = this; + + // [🤮] To mitigate changes from https://github.com/rgthree/rgthree-comfy/issues/69 + // and https://github.com/comfyanonymous/ComfyUI/issues/2193 we can try to store the workflow + // node so our nodes can find the seralized node. Works with method + // `getNodeFromInitialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff` to find a node + // while serializing. What a way to work around... + const graphSerialize = LGraph.prototype.serialize; + LGraph.prototype.serialize = function () { + const response = graphSerialize.apply(this, [...arguments] as any) as any; + rgthree.initialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff = response; + return response; + }; + + // Overrides LiteGraphs' processMouseDown to both keep state as well as dispatch a custom event. + const processMouseDown = LGraphCanvas.prototype.processMouseDown; + LGraphCanvas.prototype.processMouseDown = function (e: AdjustedMouseEvent) { + rgthree.processingMouseDown = true; + const returnVal = processMouseDown.apply(this, [...arguments] as any); + rgthree.dispatchCustomEvent("on-process-mouse-down", { originalEvent: e }); + rgthree.processingMouseDown = false; + return returnVal; + }; + + // Overrides LiteGraph's `adjustMouseEvent` to capture the last even coming in and out. Useful + // to capture the last `canvasX` and `canvasY` properties, which are not the same as LiteGraph's + // `canvas.last_mouse_position`, unfortunately. 
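+    // (The captured `lastAdjustedMouseEvent` is what lets, e.g., the rgthree context menu's
+    // "Nodes" submenu place a newly created node at the cursor's canvasX/canvasY.)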
+ const adjustMouseEvent = LGraphCanvas.prototype.adjustMouseEvent; + LGraphCanvas.prototype.adjustMouseEvent = function (e: PointerEvent) { + adjustMouseEvent.apply(this, [...arguments] as any); + rgthree.lastAdjustedMouseEvent = e as AdjustedMouseEvent; + }; + + // [🤮] Copying to clipboard clones nodes and then manipulats the linking data manually which + // does not allow a node to handle connections. This harms nodes that manually handle inputs, + // like our any-input nodes that may start with one input, and manually add new ones when one is + // attached. + const copyToClipboard = LGraphCanvas.prototype.copyToClipboard; + LGraphCanvas.prototype.copyToClipboard = function (nodes: LGraphNode[]) { + rgthree.canvasCurrentlyCopyingToClipboard = true; + rgthree.canvasCurrentlyCopyingToClipboardWithMultipleNodes = + Object.values(nodes || this.selected_nodes || []).length > 1; + copyToClipboard.apply(this, [...arguments] as any); + rgthree.canvasCurrentlyCopyingToClipboard = false; + rgthree.canvasCurrentlyCopyingToClipboardWithMultipleNodes = false; + }; + + // [⭐] Make it so when we add a group, we get to name it immediately. + const onGroupAdd = LGraphCanvas.onGroupAdd; + LGraphCanvas.onGroupAdd = function (...args: any[]) { + const graph = app.graph as TLGraph; + onGroupAdd.apply(this, [...args] as any); + LGraphCanvas.onShowPropertyEditor( + {}, + null, + null, + null, + graph._groups[graph._groups.length - 1], + ); + }; + } + + /** + * [🤮] Handles the same exact thing as ComfyApp's `invokeExtensionsAsync`, but done here since + * it is #private in ComfyApp because... of course it us. This is necessary since we purposefully + * avoid using the ComfyNode due to historical instability and inflexibility for all the advanced + * ui stuff rgthree-comfy nodes do, but we can still have other custom nodes know what's happening + * with rgthree-comfy; specifically, for `nodeCreated` as of now. + */ + async invokeExtensionsAsync(method: "nodeCreated", ...args: any[]) { + const comfyapp = app as ComfyApp; + if (CONFIG_SERVICE.getConfigValue("features.invoke_extensions_async.node_created") === false) { + const [m, a] = this.logParts( + LogLevel.INFO, + `Skipping invokeExtensionsAsync for applicable rgthree-comfy nodes`, + ); + console[m]?.(...a); + return Promise.resolve(); + } + return await Promise.all( + comfyapp.extensions.map(async (ext) => { + if (ext?.[method]) { + try { + const blocked = INVOKE_EXTENSIONS_BLOCKLIST.find((block) => + ext.name.toLowerCase().startsWith(block.name.toLowerCase()), + ); + if (blocked) { + const [n, v] = this.logger.logPartsOnceForTime( + LogLevel.WARN, + 5000, + `Blocked extension '${ext.name}' method '${method}' for rgthree-nodes because: ${blocked.reason}`, + ); + console[n]?.(...v); + return Promise.resolve(); + } + return await (ext[method] as Function)(...args, comfyapp); + } catch (error) { + const [n, v] = this.logParts( + LogLevel.ERROR, + `Error calling extension '${ext.name}' method '${method}' for rgthree-node.`, + { error }, + { extension: ext }, + { args }, + ); + console[n]?.(...v); + } + } + }), + ); + } + + /** + * Wraps `dispatchEvent` for easier CustomEvent dispatching. + */ + private dispatchCustomEvent(event: string, detail?: any) { + if (detail != null) { + return this.dispatchEvent(new CustomEvent(event, { detail })); + } + return this.dispatchEvent(new CustomEvent(event)); + } + + /** + * Initializes hooks specific to an rgthree-comfy context menu on the root menu. 
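+   *
+   * This wraps `LGraphCanvas.prototype.getCanvasMenuOptions` and splices the rgthree-comfy
+   * submenu in near ComfyUI's own "Queue Group" / "Queue Selected" entries when they exist,
+   * falling back to other nearby anchors (or index 3) rather than appending it to the end.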
+ */ + private async initializeContextMenu() { + const that = this; + setTimeout(async () => { + const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; + LGraphCanvas.prototype.getCanvasMenuOptions = function (...args: any[]) { + let existingOptions = getCanvasMenuOptions.apply(this, [...args] as any); + + const options = []; + options.push(null); // Divider + options.push(null); // Divider + options.push(null); // Divider + options.push({ + content: logoRgthree + `rgthree-comfy`, + className: "rgthree-contextmenu-item rgthree-contextmenu-main-item-rgthree-comfy", + submenu: { + options: that.getRgthreeContextMenuItems(), + }, + }); + options.push(null); // Divider + options.push(null); // Divider + + let idx = null; + idx = idx || existingOptions.findIndex((o) => o?.content?.startsWith?.("Queue Group")) + 1; + idx = + idx || existingOptions.findIndex((o) => o?.content?.startsWith?.("Queue Selected")) + 1; + idx = idx || existingOptions.findIndex((o) => o?.content?.startsWith?.("Convert to Group")); + idx = idx || existingOptions.findIndex((o) => o?.content?.startsWith?.("Arrange (")); + idx = idx || existingOptions.findIndex((o) => !o) + 1; + idx = idx || 3; + existingOptions.splice(idx, 0, ...options); + for (let i = existingOptions.length; i > 0; i--) { + if (existingOptions[i] === null && existingOptions[i + 1] === null) { + existingOptions.splice(i, 1); + } + } + + return existingOptions; + }; + }, 1016); + } + + /** + * Returns the standard menu items for an rgthree-comfy context menu. + */ + private getRgthreeContextMenuItems(): ContextMenuItem[] { + const [canvas, graph] = [app.canvas as TLGraphCanvas, app.graph as TLGraph]; + const selectedNodes = Object.values(canvas.selected_nodes || {}); + let rerouteNodes: LGraphNode[] = []; + if (selectedNodes.length) { + rerouteNodes = selectedNodes.filter((n) => n.type === "Reroute"); + } else { + rerouteNodes = graph._nodes.filter((n) => n.type == "Reroute"); + } + const rerouteLabel = selectedNodes.length ? "selected" : "all"; + + const showBookmarks = CONFIG_SERVICE.getFeatureValue("menu_bookmarks.enabled"); + const bookmarkMenuItems = showBookmarks ? 
getBookmarks() : []; + + return [ + { + content: "Nodes", + disabled: true, + className: "rgthree-contextmenu-item rgthree-contextmenu-label", + }, + { + content: iconNode + "All", + className: "rgthree-contextmenu-item", + has_submenu: true, + submenu: { + options: getNodeTypeStrings() as unknown as ContextMenuItem[], + callback: ( + value: string | ContextMenuItem, + options: IContextMenuOptions, + event: MouseEvent, + ) => { + const node = LiteGraph.createNode(addRgthree(value as string)); + node.pos = [ + rgthree.lastAdjustedMouseEvent!.canvasX, + rgthree.lastAdjustedMouseEvent!.canvasY, + ]; + canvas.graph.add(node); + canvas.selectNode(node); + app.graph.setDirtyCanvas(true, true); + }, + extra: { rgthree_doNotNest: true }, + }, + }, + + { + content: "Actions", + disabled: true, + className: "rgthree-contextmenu-item rgthree-contextmenu-label", + }, + { + content: iconGear + "Settings (rgthree-comfy)", + disabled: !!this.settingsDialog, + className: "rgthree-contextmenu-item", + callback: (...args: any[]) => { + this.settingsDialog = new RgthreeConfigDialog().show(); + this.settingsDialog.addEventListener("close", (e) => { + this.settingsDialog = null; + }); + }, + }, + { + content: iconReplace + ` Convert ${rerouteLabel} Reroutes`, + disabled: !rerouteNodes.length, + className: "rgthree-contextmenu-item", + callback: (...args: any[]) => { + const msg = + `Convert ${rerouteLabel} ComfyUI Reroutes to Reroute (rgthree) nodes? \n` + + `(First save a copy of your workflow & check reroute connections afterwards)`; + if (!window.confirm(msg)) { + return; + } + (async () => { + for (const node of [...rerouteNodes]) { + if (node.type == "Reroute") { + this.replacingReroute = node.id; + await replaceNode(node, NodeTypesString.REROUTE); + this.replacingReroute = null; + } + } + })(); + }, + }, + ...bookmarkMenuItems, + { + content: "More...", + disabled: true, + className: "rgthree-contextmenu-item rgthree-contextmenu-label", + }, + { + content: iconStarFilled + "Star on Github", + className: "rgthree-contextmenu-item rgthree-contextmenu-github", + callback: (...args: any[]) => { + window.open("https://github.com/rgthree/rgthree-comfy", "_blank"); + }, + }, + ]; + } + + /** + * Wraps an `app.queuePrompt` call setting a specific node id that we will inspect and change the + * serialized graph right before being sent (below, in our `api.queuePrompt` override). + */ + async queueOutputNodes(nodeIds: number[]) { + try { + this.queueNodeIds = nodeIds; + await app.queuePrompt(); + } catch (e) { + const [n, v] = this.logParts( + LogLevel.ERROR, + `There was an error queuing nodes ${nodeIds}`, + e, + ); + console[n]?.(...v); + } finally { + this.queueNodeIds = null; + } + } + + /** + * Recusively walks backwards from a node adding its inputs to the `newOutput` from `oldOutput`. + */ + private recursiveAddNodes(nodeId: string, oldOutput: ComfyApiFormat, newOutput: ComfyApiFormat) { + let currentId = nodeId; + let currentNode = oldOutput[currentId]!; + if (newOutput[currentId] == null) { + newOutput[currentId] = currentNode; + for (const inputValue of Object.values(currentNode.inputs || [])) { + if (Array.isArray(inputValue)) { + this.recursiveAddNodes(inputValue[0], oldOutput, newOutput); + } + } + } + return newOutput; + } + + /** + * Initialize a bunch of hooks into ComfyUI and/or LiteGraph itself so we can either keep state or + * context on what's happening so nodes can respond appropriately. 
This is usually to fix broken + * assumptions in the unowned code [🤮], but sometimes to add features or enhancements too [⭐]. + */ + private initializeComfyUIHooks() { + const rgthree = this; + + // Keep state for when the app is queuing the prompt. For instance, this is used for seed to + // understand if we're serializing because we're queueing (and return the random seed to use) or + // for saving the workflow (and keep -1, etc.). + const queuePrompt = app.queuePrompt as Function; + app.queuePrompt = async function () { + rgthree.processingQueue = true; + rgthree.dispatchCustomEvent("queue"); + try { + await queuePrompt.apply(app, [...arguments]); + } finally { + rgthree.processingQueue = false; + rgthree.dispatchCustomEvent("queue-end"); + } + }; + + // Keep state for when the app is in the middle of loading from an api JSON file. + const loadApiJson = app.loadApiJson; + app.loadApiJson = async function () { + rgthree.loadingApiJson = true; + try { + loadApiJson.apply(app, [...arguments] as any); + } finally { + rgthree.loadingApiJson = false; + } + }; + + // Keep state for when the app is serizalizing the graph to prompt. + const graphToPrompt = app.graphToPrompt; + app.graphToPrompt = async function () { + rgthree.dispatchCustomEvent("graph-to-prompt"); + let promise = graphToPrompt.apply(app, [...arguments] as any); + await promise; + rgthree.dispatchCustomEvent("graph-to-prompt-end"); + return promise; + }; + + // Override the queuePrompt for api to intercept the prompt output and, if queueNodeIds is set, + // then we only want to queue those nodes, by rewriting the api format (prompt 'output' field) + // so only those are evaluated. + const apiQueuePrompt = api.queuePrompt as Function; + api.queuePrompt = async function (index: number, prompt: ComfyApiPrompt) { + if (rgthree.queueNodeIds?.length && prompt.output) { + const oldOutput = prompt.output; + let newOutput = {}; + for (const queueNodeId of rgthree.queueNodeIds) { + rgthree.recursiveAddNodes(String(queueNodeId), oldOutput, newOutput); + } + prompt.output = newOutput; + } + rgthree.dispatchCustomEvent("comfy-api-queue-prompt-before", { + workflow: prompt.workflow, + output: prompt.output, + }); + const response = apiQueuePrompt.apply(app, [index, prompt]); + rgthree.dispatchCustomEvent("comfy-api-queue-prompt-end"); + return response; + }; + + // Hook into a clean call; allow us to clear and rgthree messages. + const clean = app.clean; + app.clean = function () { + rgthree.clearAllMessages(); + clean && clean.apply(app, [...arguments] as any); + }; + + // Hook into a data load, like from an image or JSON drop-in. This is (currently) used to + // monitor for bad linking data. + const loadGraphData = app.loadGraphData; + app.loadGraphData = function (graph: serializedLGraph) { + if (rgthree.monitorLinkTimeout) { + clearTimeout(rgthree.monitorLinkTimeout); + rgthree.monitorLinkTimeout = null; + } + rgthree.clearAllMessages(); + // Try to make a copy to use, because ComfyUI's loadGraphData will modify it. + let graphCopy: serializedLGraph | null; + try { + graphCopy = JSON.parse(JSON.stringify(graph)); + } catch (e) { + graphCopy = null; + } + setTimeout(() => { + const wasLoadingAborted = document + .querySelector(".comfy-modal-content") + ?.textContent?.includes("Loading aborted due"); + const graphToUse = wasLoadingAborted ? 
graphCopy || graph : app.graph; + const fixBadLinksResult = fixBadLinks(graphToUse as unknown as TLGraph); + if (fixBadLinksResult.hasBadLinks) { + const [n, v] = rgthree.logParts( + LogLevel.WARN, + `The workflow you've loaded has corrupt linking data. Open ${ + new URL(location.href).origin + }/rgthree/link_fixer to try to fix.`, + ); + console[n]?.(...v); + if (CONFIG_SERVICE.getConfigValue("features.show_alerts_for_corrupt_workflows")) { + rgthree.showMessage({ + id: "bad-links", + type: "warn", + message: + "The workflow you've loaded has corrupt linking data that may be able to be fixed.", + actions: [ + { + label: "Open fixer", + href: "/rgthree/link_fixer", + }, + { + label: "Fix in place", + href: "/rgthree/link_fixer", + callback: (event) => { + event.stopPropagation(); + event.preventDefault(); + if ( + confirm( + "This will attempt to fix in place. Please make sure to have a saved copy of your workflow.", + ) + ) { + try { + const fixBadLinksResult = fixBadLinks( + graphToUse as unknown as TLGraph, + true, + ); + if (!fixBadLinksResult.hasBadLinks) { + rgthree.hideMessage("bad-links"); + alert( + "Success! It's possible some valid links may have been affected. Please check and verify your workflow.", + ); + wasLoadingAborted && app.loadGraphData(fixBadLinksResult.graph); + if ( + CONFIG_SERVICE.getConfigValue("features.monitor_for_corrupt_links") || + CONFIG_SERVICE.getConfigValue("features.monitor_bad_links") + ) { + rgthree.monitorLinkTimeout = setTimeout(() => { + rgthree.monitorBadLinks(); + }, 5000); + } + } + } catch (e) { + console.error(e); + alert("Unsuccessful at fixing corrupt data. :("); + rgthree.hideMessage("bad-links"); + } + } + }, + }, + ], + }); + } + } else if ( + CONFIG_SERVICE.getConfigValue("features.monitor_for_corrupt_links") || + CONFIG_SERVICE.getConfigValue("features.monitor_bad_links") + ) { + rgthree.monitorLinkTimeout = setTimeout(() => { + rgthree.monitorBadLinks(); + }, 5000); + } + }, 100); + return loadGraphData && loadGraphData.apply(app, [...arguments] as any); + }; + } + + /** + * [🤮] Finds a node in the currently serializing workflow from the hook setup above. This is to + * mitigate breakages from https://github.com/comfyanonymous/ComfyUI/issues/2193 we can try to + * store the workflow node so our nodes can find the seralized node. + */ + getNodeFromInitialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff( + node: LGraphNode, + ): SerializedLGraphNode | null { + return ( + this.initialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff?.nodes?.find( + (n: SerializedLGraphNode) => n.id === node.id, + ) ?? null + ); + } + + /** + * Shows a message in the UI. + */ + async showMessage(data: RgthreeUiMessage) { + let container = document.querySelector(".rgthree-top-messages-container"); + if (!container) { + container = document.createElement("div"); + container.classList.add("rgthree-top-messages-container"); + document.body.appendChild(container); + } + // If we have a dialog open then we want to append the message to the dialog so they show over + // the modal. + const dialogs = query("dialog[open]"); + if (dialogs.length) { + let dialog = dialogs[dialogs.length - 1]!; + dialog.appendChild(container); + dialog.addEventListener("close", (e) => { + document.body.appendChild(container!); + }); + } + // Hide if we exist. 
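+    // That is, if a message with this id is already showing, remove it first so this call replaces
+    // it in place. A minimal usage sketch (values are illustrative):
+    //   rgthree.showMessage({id: "example-id", type: "warn", message: "Heads up!", timeout: 4000});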
+ await this.hideMessage(data.id); + + const messageContainer = document.createElement("div"); + messageContainer.setAttribute("type", data.type || "info"); + + const message = document.createElement("span"); + message.innerHTML = data.message; + messageContainer.appendChild(message); + + for (let a = 0; a < (data.actions || []).length; a++) { + const action = data.actions![a]!; + if (a > 0) { + const sep = document.createElement("span"); + sep.innerHTML = " | "; + messageContainer.appendChild(sep); + } + + const actionEl = document.createElement("a"); + actionEl.innerText = action.label; + if (action.href) { + actionEl.target = "_blank"; + actionEl.href = action.href; + } + if (action.callback) { + actionEl.onclick = (e) => { + return action.callback!(e); + }; + } + messageContainer.appendChild(actionEl); + } + + const messageAnimContainer = document.createElement("div"); + messageAnimContainer.setAttribute("msg-id", data.id); + messageAnimContainer.appendChild(messageContainer); + container.appendChild(messageAnimContainer); + + // Add. Wait. Measure. Wait. Anim. + await wait(64); + messageAnimContainer.style.marginTop = `-${messageAnimContainer.offsetHeight}px`; + await wait(64); + messageAnimContainer.classList.add("-show"); + + if (data.timeout) { + await wait(data.timeout); + this.hideMessage(data.id); + } + } + + /** + * Hides a message in the UI. + */ + async hideMessage(id: string) { + const msg = document.querySelector(`.rgthree-top-messages-container > [msg-id="${id}"]`); + if (msg?.classList.contains("-show")) { + msg.classList.remove("-show"); + await wait(750); + } + msg && msg.remove(); + } + + /** + * Clears all messages in the UI. + */ + async clearAllMessages() { + let container = document.querySelector(".rgthree-top-messages-container"); + container && (container.innerHTML = ""); + } + + setLogLevel(level?: LogLevel | string) { + if (typeof level === "string") { + level = LogLevelKeyToLogLevel[CONFIG_SERVICE.getConfigValue("log_level")]; + } + if (level != null) { + GLOBAL_LOG_LEVEL = level; + } + } + + logParts(level: LogLevel, message?: string, ...args: any[]) { + return this.logger.logParts(level, message, ...args); + } + + newLogSession(name?: string) { + return this.logger.newSession(name); + } + + isDevMode() { + if (window.location.href.includes("rgthree-dev=false")) { + return false; + } + return GLOBAL_LOG_LEVEL >= LogLevel.DEBUG || window.location.href.includes("rgthree-dev"); + } + + isDebugMode() { + if (!this.isDevMode() || window.location.href.includes("rgthree-debug=false")) { + return false; + } + return window.location.href.includes("rgthree-debug"); + } + + monitorBadLinks() { + const badLinksFound = fixBadLinks(app.graph); + if (badLinksFound.hasBadLinks && !this.monitorBadLinksAlerted) { + this.monitorBadLinksAlerted = true; + alert( + `Problematic links just found in live data. Can you save your workflow and file a bug with ` + + `the last few steps you took to trigger this at ` + + `https://github.com/rgthree/rgthree-comfy/issues. Thank you!`, + ); + } else if (!badLinksFound.hasBadLinks) { + // Clear the alert once fixed so we can alert again. + this.monitorBadLinksAlerted = false; + } + this.monitorLinkTimeout = setTimeout(() => { + this.monitorBadLinks(); + }, 5000); + } +} + +function getBookmarks(): ContextMenuItem[] { + const graph: TLGraph = app.graph; + + // Sorts by Title. + // I could see an option to sort by either Shortcut, Title, or Position. 
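+  // A sort-by-shortcut variant would be a one-line change to the sort below, e.g. (hypothetical):
+  //   .sort((a, b) => a.shortcutKey.localeCompare(b.shortcutKey))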
+  const bookmarks = graph._nodes
+    .filter((n): n is Bookmark => n.type === NodeTypesString.BOOKMARK)
+    .sort((a, b) => a.title.localeCompare(b.title))
+    .map((n) => ({
+      content: `[${n.shortcutKey}] ${n.title}`,
+      className: "rgthree-contextmenu-item",
+      callback: () => {
+        n.canvasToBookmark();
+      },
+    }));
+
+  return !bookmarks.length
+    ? []
+    : [
+        {
+          content: "🔖 Bookmarks",
+          disabled: true,
+          className: "rgthree-contextmenu-item rgthree-contextmenu-label",
+        },
+        ...bookmarks,
+      ];
+}
+
+export const rgthree = new Rgthree();
+// Expose it on window because, why not.
+(window as any).rgthree = rgthree;
diff --git a/rgthree-comfy/src_web/comfyui/seed.ts b/rgthree-comfy/src_web/comfyui/seed.ts
new file mode 100644
index 0000000000000000000000000000000000000000..333c3a61c2a77124bd52c2df0940a3cd25849cce
--- /dev/null
+++ b/rgthree-comfy/src_web/comfyui/seed.ts
@@ -0,0 +1,300 @@
+import { app } from "scripts/app.js";
+import { ComfyWidgets } from "scripts/widgets.js";
+import type {
+  ContextMenuItem,
+  IContextMenuOptions,
+  ContextMenu,
+  LGraphNode as TLGraphNode,
+  IWidget,
+  LGraphCanvas,
+  SerializedLGraphNode,
+} from "typings/litegraph.js";
+import type {
+  ComfyObjectInfo,
+  ComfyWidget,
+  ComfyNodeConstructor,
+  ComfyApiPrompt,
+} from "typings/comfy.js";
+import { RgthreeBaseServerNode } from "./base_node.js";
+import { rgthree } from "./rgthree.js";
+import { addConnectionLayoutSupport } from "./utils.js";
+import { NodeTypesString } from "./constants.js";
+import { SerializedNode } from "typings/index.js";
+
+const LAST_SEED_BUTTON_LABEL = "♻️ (Use Last Queued Seed)";
+
+// Sentinel seed values: -1 randomizes on each queue; -2/-3 increment/decrement the last queued seed.
+const SPECIAL_SEED_RANDOM = -1;
+const SPECIAL_SEED_INCREMENT = -2;
+const SPECIAL_SEED_DECREMENT = -3;
+const SPECIAL_SEEDS = [SPECIAL_SEED_RANDOM, SPECIAL_SEED_INCREMENT, SPECIAL_SEED_DECREMENT];
+
+interface SeedSerializedCtx {
+  inputSeed?: number;
+  seedUsed?: number;
+}
+
+class RgthreeSeed extends RgthreeBaseServerNode {
+  static override title = NodeTypesString.SEED;
+  static override type = NodeTypesString.SEED;
+  static comfyClass = NodeTypesString.SEED;
+
+  override serialize_widgets = true;
+
+  private logger = rgthree.newLogSession(`[Seed]`);
+
+  static override exposedActions = ["Randomize Each Time", "Use Last Queued Seed"];
+
+  lastSeed?: number = undefined;
+  serializedCtx: SeedSerializedCtx = {};
+  seedWidget!: IWidget;
+  lastSeedButton!: IWidget;
+  lastSeedValue: ComfyWidget | null = null;
+
+  randMax = 1125899906842624;
+  // We can have a full range of seeds, including negative. But, for the randomRange we'll
+  // only generate positives, since that's what folks assume.
+  // const min = Math.max(-1125899906842624, this.seedWidget.options.min);
+  randMin = 0;
+  randomRange = 1125899906842624;
+
+  private handleApiHijackingBound = this.handleApiHijacking.bind(this);
+
+  constructor(title = RgthreeSeed.title) {
+    super(title);
+
+    rgthree.addEventListener(
+      "comfy-api-queue-prompt-before",
+      this.handleApiHijackingBound as EventListener,
+    );
+  }
+
+  override onRemoved() {
+    // Clean up the listener added in the constructor.
+    rgthree.removeEventListener(
+      "comfy-api-queue-prompt-before",
+      this.handleApiHijackingBound as EventListener,
+    );
+  }
+
+  override configure(info: SerializedLGraphNode): void {
+    super.configure(info);
+    if (this.properties?.["showLastSeed"]) {
+      this.addLastSeedValue();
+    }
+  }
+
+  override async handleAction(action: string) {
+    if (action === "Randomize Each Time") {
+      this.seedWidget.value = SPECIAL_SEED_RANDOM;
+    } else if (action === "Use Last Queued Seed") {
+      this.seedWidget.value = this.lastSeed != null ?
this.lastSeed : this.seedWidget.value;
+      this.lastSeedButton.name = LAST_SEED_BUTTON_LABEL;
+      this.lastSeedButton.disabled = true;
+    }
+  }
+
+  override onNodeCreated() {
+    super.onNodeCreated?.();
+    // Grab the already available widgets, and remove the built-in control_after_generate
+    for (const [i, w] of this.widgets.entries()) {
+      if (w.name === "seed") {
+        this.seedWidget = w; // as ComfyWidget;
+        this.seedWidget.value = SPECIAL_SEED_RANDOM;
+      } else if (w.name === "control_after_generate") {
+        this.widgets.splice(i, 1);
+      }
+    }
+
+    // Update random values in case seed comes down with different options.
+    let step = this.seedWidget.options.step || 1;
+    this.randMax = Math.min(1125899906842624, this.seedWidget.options.max);
+    // We can have a full range of seeds, including negative. But, for the randomRange we'll
+    // only generate positives, since that's what folks assume.
+    this.randMin = Math.max(0, this.seedWidget.options.min);
+    this.randomRange = (this.randMax - Math.max(0, this.randMin)) / (step / 10);
+
+    this.addWidget(
+      "button",
+      "🎲 Randomize Each Time",
+      null,
+      () => {
+        this.seedWidget.value = SPECIAL_SEED_RANDOM;
+      },
+      { serialize: false },
+    ) as ComfyWidget;
+
+    this.addWidget(
+      "button",
+      "🎲 New Fixed Random",
+      null,
+      () => {
+        this.seedWidget.value =
+          Math.floor(Math.random() * this.randomRange) * (step / 10) + this.randMin;
+      },
+      { serialize: false },
+    );
+
+    this.lastSeedButton = this.addWidget(
+      "button",
+      LAST_SEED_BUTTON_LABEL,
+      null,
+      () => {
+        this.seedWidget.value = this.lastSeed != null ? this.lastSeed : this.seedWidget.value;
+        this.lastSeedButton.name = LAST_SEED_BUTTON_LABEL;
+        this.lastSeedButton.disabled = true;
+      },
+      { width: 50, serialize: false },
+    );
+    this.lastSeedButton.disabled = true;
+  }
+
+  override getExtraMenuOptions(canvas: LGraphCanvas, options: ContextMenuItem[]): void {
+    super.getExtraMenuOptions?.apply(this, [...arguments] as any);
+    options.splice(options.length - 1, 0, {
+      content: "Show/Hide Last Seed Value",
+      callback: (
+        _value: ContextMenuItem,
+        _options: IContextMenuOptions,
+        _event: MouseEvent,
+        _parentMenu: ContextMenu | undefined,
+        _node: TLGraphNode,
+      ) => {
+        this.properties["showLastSeed"] = !this.properties["showLastSeed"];
+        if (this.properties["showLastSeed"]) {
+          this.addLastSeedValue();
+        } else {
+          this.removeLastSeedValue();
+        }
+      },
+    });
+  }
+
+  addLastSeedValue() {
+    if (this.lastSeedValue) return;
+    this.lastSeedValue = ComfyWidgets["STRING"](
+      this,
+      "last_seed",
+      ["STRING", { multiline: true }],
+      app,
+    ).widget;
+    this.lastSeedValue!.inputEl!.readOnly = true;
+    this.lastSeedValue!.inputEl!.style.fontSize = "0.75rem";
+    this.lastSeedValue!.inputEl!.style.textAlign = "center";
+    this.computeSize();
+  }
+
+  removeLastSeedValue() {
+    if (!this.lastSeedValue) return;
+    this.lastSeedValue!.inputEl!.remove();
+    this.widgets.splice(this.widgets.indexOf(this.lastSeedValue as IWidget), 1);
+    this.lastSeedValue = null;
+    this.computeSize();
+  }
+
+  /**
+   * Intercepts the prompt right before ComfyUI sends it to the server (as fired from rgthree) so we
+   * can inspect the prompt and workflow data and swap in the seeds.
+   *
+   * Note, the original implementation tried to change the widget value itself when the graph was
+   * queued (and then relied on ComfyUI serializing the changed data) and then changing it back.
+   * This worked well until other extensions kept calling graphToPrompt during their own
+   * asynchronous operations, causing the widget to get confused without a reliable state to
+   * reflect upon.
+   */
+  handleApiHijacking(e: CustomEvent) {
+    // Don't do any work if we're muted/bypassed.
+    if (this.mode === LiteGraph.NEVER || this.mode === 4) {
+      return;
+    }
+
+    const workflow = e.detail.workflow;
+    const output = e.detail.output;
+
+    let workflowNode = workflow?.nodes?.find((n: SerializedNode) => n.id === this.id) ?? null;
+    let outputInputs = output?.[this.id]?.inputs;
+
+    if (
+      !workflowNode ||
+      !outputInputs ||
+      outputInputs[this.seedWidget.name || "seed"] === undefined
+    ) {
+      const [n, v] = this.logger.warnParts(
+        `Node ${this.id} not found in prompt data sent to server. This may be fine if only ` +
+          `queuing part of the workflow. If not, then this could be a bug.`,
+      );
+      console[n]?.(...v);
+      return;
+    }
+
+    const seedToUse = this.getSeedToUse();
+    const seedWidgetIndex = this.widgets.indexOf(this.seedWidget);
+
+    workflowNode.widgets_values![seedWidgetIndex] = seedToUse;
+    outputInputs[this.seedWidget.name || "seed"] = seedToUse;
+
+    this.lastSeed = seedToUse;
+    if (seedToUse != this.seedWidget.value) {
+      this.lastSeedButton.name = `♻️ ${this.lastSeed}`;
+      this.lastSeedButton.disabled = false;
+    } else {
+      this.lastSeedButton.name = LAST_SEED_BUTTON_LABEL;
+      this.lastSeedButton.disabled = true;
+    }
+    if (this.lastSeedValue) {
+      this.lastSeedValue.value = `Last Seed: ${this.lastSeed}`;
+    }
+  }
+
+  /**
+   * Determines a seed to use depending on the seed widget's current value and the last used seed.
+   * There are no side effects to calling this method.
+   */
+  private getSeedToUse() {
+    const inputSeed: number = this.seedWidget.value;
+    let seedToUse: number | null = null;
+
+    // If our input seed was a special seed, then handle it.
+    if (SPECIAL_SEEDS.includes(inputSeed)) {
+      // If the last seed was not a special seed and we have increment/decrement, then do that on
+      // the last seed.
+      if (typeof this.lastSeed === "number" && !SPECIAL_SEEDS.includes(this.lastSeed)) {
+        if (inputSeed === SPECIAL_SEED_INCREMENT) {
+          seedToUse = this.lastSeed + 1;
+        } else if (inputSeed === SPECIAL_SEED_DECREMENT) {
+          seedToUse = this.lastSeed - 1;
+        }
+      }
+      // If we don't have a seed to use, or it's a special seed (like we incremented into one), then
+      // we randomize.
+      if (seedToUse == null || SPECIAL_SEEDS.includes(seedToUse)) {
+        seedToUse =
+          Math.floor(Math.random() * this.randomRange) *
+            ((this.seedWidget.options.step || 1) / 10) +
+          this.randMin;
+      }
+    }
+
+    return seedToUse ??
inputSeed; + } + + static override setUp(comfyClass: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, RgthreeSeed); + } + + static override onRegisteredForOverride(comfyClass: any, ctxClass: any) { + addConnectionLayoutSupport(RgthreeSeed, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + setTimeout(() => { + RgthreeSeed.category = comfyClass.category; + }); + } +} + +app.registerExtension({ + name: "rgthree.Seed", + async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) { + if (nodeData.name === RgthreeSeed.type) { + RgthreeSeed.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/src_web/comfyui/services/bookmarks_services.ts b/rgthree-comfy/src_web/comfyui/services/bookmarks_services.ts new file mode 100644 index 0000000000000000000000000000000000000000..705f263e397808fb947def7ed725bc60fffcac3c --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/services/bookmarks_services.ts @@ -0,0 +1,18 @@ +import type { Bookmark } from "../bookmark.js"; + +import { app } from "scripts/app.js"; +import { NodeTypesString } from "../constants.js"; + +class BookmarksService { + /** + * Gets a list of the current bookmarks within the current workflow. + */ + getCurrentBookmarks() { + return app.graph._nodes + .filter((n): n is Bookmark => n.type === NodeTypesString.BOOKMARK) + .sort((a, b) => a.title.localeCompare(b.title)); + } +} + +/** The BookmarksService singleton. */ +export const SERVICE = new BookmarksService(); diff --git a/rgthree-comfy/src_web/comfyui/services/config_service.ts b/rgthree-comfy/src_web/comfyui/services/config_service.ts new file mode 100644 index 0000000000000000000000000000000000000000..0e7176c9f7e13316e7c26b645f754481a395ff65 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/services/config_service.ts @@ -0,0 +1,40 @@ +// @ts-ignore +import { rgthreeConfig } from "rgthree/config.js"; +import { getObjectValue, setObjectValue } from "rgthree/common/shared_utils.js"; +import { rgthreeApi } from "rgthree/common/rgthree_api.js"; + +/** + * A singleton service exported as `SERVICE` to handle configuration routines. + */ +class ConfigService extends EventTarget { + getConfigValue(key: string, def?: any) { + return getObjectValue(rgthreeConfig, key, def); + } + + getFeatureValue(key: string, def?: any) { + key = "features." + key.replace(/^features\./, ""); + return getObjectValue(rgthreeConfig, key, def); + } + + /** + * Given an object of key:value changes it will send to the server and wait for a successful + * response before setting the values on the local rgthreeConfig. + */ + async setConfigValues(changed: { [key: string]: any }) { + const body = new FormData(); + body.append("json", JSON.stringify(changed)); + const response = await rgthreeApi.fetchJson("/config", { method: "POST", body }); + if (response.status === "ok") { + for (const [key, value] of Object.entries(changed)) { + setObjectValue(rgthreeConfig, key, value); + this.dispatchEvent(new CustomEvent("config-change", { detail: { key, value } })); + } + } else { + return false; + } + return true; + } +} + +/** The ConfigService singleton. 
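+ *
+ * A rough usage sketch (the keys shown are illustrative, borrowed from elsewhere in this repo):
+ *   SERVICE.getConfigValue("features.progress_bar.position");
+ *   await SERVICE.setConfigValues({ "features.progress_bar.height": 14 });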
*/ +export const SERVICE = new ConfigService(); diff --git a/rgthree-comfy/src_web/comfyui/services/context_service.ts b/rgthree-comfy/src_web/comfyui/services/context_service.ts new file mode 100644 index 0000000000000000000000000000000000000000..f66ca1f36b7ae2a4c57ed3a2f3b082218d0b6f6d --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/services/context_service.ts @@ -0,0 +1,76 @@ +import type {DynamicContextNodeBase} from "../dynamic_context_base.js"; + +import {app} from "scripts/app.js"; +import {NodeTypesString} from "../constants.js"; +import {getConnectedOutputNodesAndFilterPassThroughs} from "../utils.js"; +import {INodeInputSlot, INodeOutputSlot, INodeSlot, LGraphNode} from "typings/litegraph.js"; + +export let SERVICE: ContextService; + +const OWNED_PREFIX = "+"; +const REGEX_PREFIX = /^[\+⚠️]\s*/; +const REGEX_EMPTY_INPUT = /^\+\s*$/; + +export function stripContextInputPrefixes(name: string) { + return name.replace(REGEX_PREFIX, ""); +} + +export function getContextOutputName(inputName: string) { + if (inputName === "base_ctx") return "CONTEXT"; + return stripContextInputPrefixes(inputName).toUpperCase(); +} + +export enum InputMutationOperation { + "UNKNOWN", + "ADDED", + "REMOVED", + "RENAMED", +} + +export type InputMutation = { + operation: InputMutationOperation; + node: DynamicContextNodeBase; + slotIndex: number; + slot: INodeSlot; +}; + +export class ContextService { + + constructor() { + if (SERVICE) { + throw new Error("ContextService was already instantiated."); + } + } + + onInputChanges(node: any, mutation: InputMutation) { + const childCtxs = getConnectedOutputNodesAndFilterPassThroughs( + node, + node, + 0, + ) as DynamicContextNodeBase[]; + for (const childCtx of childCtxs) { + childCtx.handleUpstreamMutation(mutation); + } + } + + getDynamicContextInputsData(node: DynamicContextNodeBase) { + return node + .getContextInputsList() + .map((input: INodeInputSlot, index: number) => ({ + name: stripContextInputPrefixes(input.name), + type: String(input.type), + index, + })) + .filter((i) => i.type !== "*"); + } + + getDynamicContextOutputsData(node: LGraphNode) { + return node.outputs.map((output: INodeOutputSlot, index: number) => ({ + name: stripContextInputPrefixes(output.name), + type: String(output.type), + index, + })); + } +} + +SERVICE = new ContextService(); diff --git a/rgthree-comfy/src_web/comfyui/services/fast_groups_service.ts b/rgthree-comfy/src_web/comfyui/services/fast_groups_service.ts new file mode 100644 index 0000000000000000000000000000000000000000..8fece9b78a1a29a2cb9107feb8ac8acb1956e1b5 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/services/fast_groups_service.ts @@ -0,0 +1,195 @@ +import { app } from "scripts/app.js"; +import type { BaseFastGroupsModeChanger } from "../fast_groups_muter.js"; +import { + type LGraph as TLGraph, + type LGraphCanvas as TLGraphCanvas, + LGraphGroup, + Vector4, +} from "typings/litegraph.js"; + +/** + * A service that keeps global state that can be shared by multiple FastGroupsMuter or + * FastGroupsBypasser nodes rather than calculate it on it's own. 
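+ *
+ * Consumers register and deregister themselves via `SERVICE.addFastGroupNode(node)` and
+ * `SERVICE.removeFastGroupNode(node)` (a sketch; the node-side hooks that make these calls live in
+ * fast_groups_muter.ts and are not shown in this diff), and the service then schedules periodic
+ * `refreshWidgets()` calls on each registered node.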
+ */ +class FastGroupsService { + private msThreshold = 400; + private msLastUnsorted = 0; + private msLastAlpha = 0; + private msLastPosition = 0; + + private groupsUnsorted: LGraphGroup[] = []; + private groupsSortedAlpha: LGraphGroup[] = []; + private groupsSortedPosition: LGraphGroup[] = []; + + private readonly fastGroupNodes: BaseFastGroupsModeChanger[] = []; + + private runScheduledForMs: number | null = null; + private runScheduleTimeout: number | null = null; + private runScheduleAnimation: number | null = null; + + private cachedNodeBoundings: { [key: number]: Vector4 } | null = null; + + constructor() { + // Don't need to do anything, wait until a signal. + } + + addFastGroupNode(node: BaseFastGroupsModeChanger) { + this.fastGroupNodes.push(node); + // Schedule it because the node may not be ready to refreshWidgets (like, when added it may + // not have cloned properties to filter against, etc.). + this.scheduleRun(8); + } + + removeFastGroupNode(node: BaseFastGroupsModeChanger) { + const index = this.fastGroupNodes.indexOf(node); + if (index > -1) { + this.fastGroupNodes.splice(index, 1); + } + // If we have no more group nodes, then clear out data; it could be because of a canvas clear. + if (!this.fastGroupNodes?.length) { + this.clearScheduledRun(); + this.groupsUnsorted = []; + this.groupsSortedAlpha = []; + this.groupsSortedPosition = []; + } + } + + private run() { + // We only run if we're scheduled, so if we're not, then bail. + if (!this.runScheduledForMs) { + return; + } + for (const node of this.fastGroupNodes) { + node.refreshWidgets(); + } + this.clearScheduledRun(); + this.scheduleRun(); + } + + private scheduleRun(ms = 500) { + // If we got a request for an immediate schedule and already have on scheduled for longer, then + // cancel the long one to expediate a fast one. + if (this.runScheduledForMs && ms < this.runScheduledForMs) { + this.clearScheduledRun(); + } + if (!this.runScheduledForMs && this.fastGroupNodes.length) { + this.runScheduledForMs = ms; + this.runScheduleTimeout = setTimeout(() => { + this.runScheduleAnimation = requestAnimationFrame(() => this.run()); + }, ms); + } + } + + private clearScheduledRun() { + this.runScheduleTimeout && clearTimeout(this.runScheduleTimeout); + this.runScheduleAnimation && cancelAnimationFrame(this.runScheduleAnimation); + this.runScheduleTimeout = null; + this.runScheduleAnimation = null; + this.runScheduledForMs = null; + } + + /** + * Returns the boundings for all nodes on the graph, then clears it after a short delay. This is + * to increase efficiency by caching the nodes' boundings when multiple groups are on the page. + */ + getBoundingsForAllNodes() { + if (!this.cachedNodeBoundings) { + this.cachedNodeBoundings = {}; + for (const node of app.graph._nodes) { + this.cachedNodeBoundings[node.id] = node.getBounding(); + } + setTimeout(() => { + this.cachedNodeBoundings = null; + }, 50); + } + return this.cachedNodeBoundings; + } + + /** + * This overrides `LGraphGroup.prototype.recomputeInsideNodes` to be much more efficient when + * calculating for many groups at once (only compute all nodes once in `getBoundingsForAllNodes`). 
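+   *
+   * In other words, each node's `getBounding()` is computed once per short-lived cache window (see
+   * `getBoundingsForAllNodes` above) instead of once per group being recalculated.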
+ */ + recomputeInsideNodesForGroup(group: LGraphGroup) { + const cachedBoundings = this.getBoundingsForAllNodes(); + const nodes = group.graph._nodes; + group._nodes.length = 0; + + for (const node of nodes) { + const node_bounding = cachedBoundings[node.id]; + if (!node_bounding || !LiteGraph.overlapBounding(group._bounding, node_bounding)) { + continue; + } + group._nodes.push(node); + } + } + + /** + * Everything goes through getGroupsUnsorted, so we only get groups once. However, LiteGraph's + * `recomputeInsideNodes` is inefficient when calling multiple groups (it iterates over all nodes + * each time). So, we'll do our own dang thing, once. + */ + private getGroupsUnsorted(now: number) { + const canvas = app.canvas as TLGraphCanvas; + const graph = app.graph as TLGraph; + + if ( + // Don't recalculate nodes if we're moving a group (added by ComfyUI in app.js) + !canvas.selected_group_moving && + (!this.groupsUnsorted.length || now - this.msLastUnsorted > this.msThreshold) + ) { + this.groupsUnsorted = [...graph._groups]; + for (const group of this.groupsUnsorted) { + this.recomputeInsideNodesForGroup(group); + (group as any)._rgthreeHasAnyActiveNode = group._nodes.some( + (n) => n.mode === LiteGraph.ALWAYS, + ); + } + this.msLastUnsorted = now; + } + return this.groupsUnsorted; + } + + private getGroupsAlpha(now: number) { + const graph = app.graph as TLGraph; + if (!this.groupsSortedAlpha.length || now - this.msLastAlpha > this.msThreshold) { + this.groupsSortedAlpha = [...this.getGroupsUnsorted(now)].sort((a, b) => { + return a.title.localeCompare(b.title); + }); + this.msLastAlpha = now; + } + return this.groupsSortedAlpha; + } + + private getGroupsPosition(now: number) { + const graph = app.graph as TLGraph; + if (!this.groupsSortedPosition.length || now - this.msLastPosition > this.msThreshold) { + this.groupsSortedPosition = [...this.getGroupsUnsorted(now)].sort((a, b) => { + // Sort by y, then x, clamped to 30. + const aY = Math.floor(a._pos[1] / 30); + const bY = Math.floor(b._pos[1] / 30); + if (aY == bY) { + const aX = Math.floor(a._pos[0] / 30); + const bX = Math.floor(b._pos[0] / 30); + return aX - bX; + } + return aY - bY; + }); + this.msLastPosition = now; + } + return this.groupsSortedPosition; + } + + getGroups(sort?: string) { + const now = +new Date(); + if (sort === "alphanumeric") { + return this.getGroupsAlpha(now); + } + if (sort === "position") { + return this.getGroupsPosition(now); + } + return this.getGroupsUnsorted(now); + } +} + +/** The FastGroupsService singleton. */ +export const SERVICE = new FastGroupsService(); diff --git a/rgthree-comfy/src_web/comfyui/services/key_events_services.ts b/rgthree-comfy/src_web/comfyui/services/key_events_services.ts new file mode 100644 index 0000000000000000000000000000000000000000..9514dafa33009f3de54bac87aec791b8101d5615 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/services/key_events_services.ts @@ -0,0 +1,169 @@ +/** + * A service responsible for captruing keys within LiteGraph's canvas, and outside of it, allowing + * nodes and other services to confidently determine what's going on. 
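+ *
+ * A usage sketch (the service is exported below as SERVICE; the key names are illustrative):
+ *   SERVICE.areAllKeysDown(["CONTROL", "ALT"]);
+ *   SERVICE.areOnlyKeysDown("shift + a", true);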
+ */ +class KeyEventService extends EventTarget { + readonly downKeys: { [key: string]: boolean } = {}; + + ctrlKey = false; + altKey = false; + metaKey = false; + shiftKey = false; + + private readonly isMac: boolean = !!( + navigator.platform?.toLocaleUpperCase().startsWith("MAC") || + (navigator as any).userAgentData?.platform?.toLocaleUpperCase().startsWith("MAC") + ); + + constructor() { + super(); + this.initialize(); + } + + initialize() { + const that = this; + // [🤮] Sometimes ComfyUI and/or LiteGraph stop propagation of key events which makes it hard + // to determine if keys are currently pressed. To attempt to get around this, we'll hijack + // LiteGraph's processKey to try to get better consistency. + const processKey = LGraphCanvas.prototype.processKey; + LGraphCanvas.prototype.processKey = function (e: KeyboardEvent) { + if (e.type === "keydown" || e.type === "keyup") { + that.handleKeyDownOrUp(e); + } + return processKey.apply(this, [...arguments] as any) as any; + }; + + // Now that ComfyUI has more non-canvas UI (like the top bar), we listen on window as well, and + // de-dupe when we get multiple events from both window and/or LiteGraph. + window.addEventListener("keydown", (e) => { + that.handleKeyDownOrUp(e); + }); + window.addEventListener("keyup", (e) => { + that.handleKeyDownOrUp(e); + }); + + // If we get a visibilitychange, then clear the keys since we can't listen for keys up/down when + // not visible. + document.addEventListener("visibilitychange", (e) => { + this.clearKeydowns(); + }); + + // If we get a blur, then also clear the keys since we can't listen for keys up/down when + // blurred. This can happen w/o a visibilitychange, like a browser alert. + window.addEventListener("blur", (e) => { + this.clearKeydowns(); + }); + } + + /** + * Adds a new queue item, unless the last is the same. + */ + handleKeyDownOrUp(e: KeyboardEvent) { + const key = e.key.toLocaleUpperCase(); + // If we're already down, or already up, then ignore and don't fire. + if ((e.type === 'keydown' && this.downKeys[key] === true) + || (e.type === 'keyup' && this.downKeys[key] === undefined)) { + return; + } + + this.ctrlKey = !!e.ctrlKey; + this.altKey = !!e.altKey; + this.metaKey = !!e.metaKey; + this.shiftKey = !!e.shiftKey; + if (e.type === "keydown") { + this.downKeys[key] = true; + this.dispatchCustomEvent("keydown", { originalEvent: e }); + } else if (e.type === "keyup") { + // See https://github.com/rgthree/rgthree-comfy/issues/238 + // A little bit of a hack, but Mac reportedly does something odd with copy/paste. ComfyUI + // gobbles the copy event propagation, but it happens for paste too and reportedly 'Enter' which + // I can't find a reason for in LiteGraph/comfy. So, for Mac only, whenever we lift a Command + // (META) key, we'll also clear any other keys. + if (key === "META" && this.isMac) { + this.clearKeydowns(); + } else { + delete this.downKeys[key]; + // this.debugRenderKeys(); + } + this.dispatchCustomEvent("keyup", { originalEvent: e }); + } + + } + + private clearKeydowns() { + this.ctrlKey = false; + this.altKey = false; + this.metaKey = false; + this.shiftKey = false; + for (const key in this.downKeys) delete this.downKeys[key]; + } + + /** + * Wraps `dispatchEvent` for easier CustomEvent dispatching. + */ + private dispatchCustomEvent(event: string, detail?: any) { + if (detail != null) { + return this.dispatchEvent(new CustomEvent(event, { detail })); + } + return this.dispatchEvent(new CustomEvent(event)); + } + + /** + * Parses a shortcut string. 
+ * + * - 's' => ['S'] + * - 'shift + c' => ['SHIFT', 'C'] + * - 'shift + meta + @' => ['SHIFT', 'META', '@'] + * - 'shift + + + @' => ['SHIFT', '__PLUS__', '='] + * - '+ + p' => ['__PLUS__', 'P'] + */ + private getKeysFromShortcut(shortcut: string | string[]) { + let keys; + if (typeof shortcut === "string") { + // Rip all spaces out. Note, Comfy swallows space, so we don't have to handle it. Otherwise, + // we would require space to be fed as "Space" or "Spacebar" instead of " ". + shortcut = shortcut.replace(/\s/g, ""); + // Change a real "+" to something we can encode. + shortcut = shortcut.replace(/^\+/, "__PLUS__").replace(/\+\+/, "+__PLUS__"); + keys = shortcut.split("+").map((i) => i.replace("__PLUS__", "+")); + } else { + keys = [...shortcut]; + } + return keys.map((k) => k.toLocaleUpperCase()); + } + + /** + * Checks if all keys passed in are down. + */ + areAllKeysDown(keys: string | string[]) { + keys = this.getKeysFromShortcut(keys); + return keys.every((k) => { + return this.downKeys[k]; + }); + } + + /** + * Checks if only the keys passed in are down; optionally and additionally allowing "shift" key. + */ + areOnlyKeysDown(keys: string | string[], alsoAllowShift = false) { + keys = this.getKeysFromShortcut(keys); + const allKeysDown = this.areAllKeysDown(keys); + const downKeysLength = Object.values(this.downKeys).length; + // All keys are down and they're the only ones. + if (allKeysDown && keys.length === downKeysLength) { + return true; + } + // Special case allowing the shift key in addition to the shortcut keys. This helps when a user + // may had originally defined "$" as a shortcut, but needs to press "shift + $" since it's an + // upper key character, etc. + if (alsoAllowShift && !keys.includes("SHIFT") && keys.length === downKeysLength - 1) { + // If we're holding down shift, have one extra key held down, and the original keys don't + // include shift, then we're good to go. + return allKeysDown && this.areAllKeysDown(["SHIFT"]); + } + return false; + } +} + +/** The KeyEventService singleton. */ +export const SERVICE = new KeyEventService(); diff --git a/rgthree-comfy/src_web/comfyui/testing/comfyui_env.ts b/rgthree-comfy/src_web/comfyui/testing/comfyui_env.ts new file mode 100644 index 0000000000000000000000000000000000000000..af186ce76bffeffc4260e247208c06b79afbd6d6 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/testing/comfyui_env.ts @@ -0,0 +1,67 @@ +import { app } from "scripts/app.js"; +import { NodeTypesString } from "../constants.js"; +import { wait } from "rgthree/common/shared_utils.js"; +import type { LGraphNode } from "typings/litegraph.js"; + +type addNodeOptions = { + placement?: string; +}; + +/** + * A testing environment to make setting up, clearing, and queuing more predictable in an + * integration test environment. 
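+ *
+ * Rough usage sketch (node type constants come from constants.ts; placement values match addNode
+ * below):
+ *   const env = new ComfyUITestEnvironment();
+ *   await env.clear();
+ *   const a = await env.addNode(NodeTypesString.KSAMPLER_CONFIG);
+ *   const b = await env.addNode(NodeTypesString.DISPLAY_ANY, {placement: "under"});
+ *   a.connect(0, b, 0);
+ *   await env.queuePrompt();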
+ */ +export class ComfyUITestEnvironment { + private lastNode: LGraphNode | null = null; + private maxY = 0; + + constructor() {} + + async addNode(nodeString: string, options: addNodeOptions = {}) { + const [canvas, graph] = [app.canvas, app.graph]; + const node = LiteGraph.createNode(nodeString); + let x = 0; + let y = 30; + if (this.lastNode) { + const placement = options.placement || "right"; + if (placement === "under") { + x = this.lastNode.pos[0]; + y = this.lastNode.pos[1] + this.lastNode.size[1] + 30; + } else if (placement === "right") { + x = this.lastNode.pos[0] + this.lastNode.size[0] + 100; + y = this.lastNode.pos[1]; + } else if (placement === "start") { + x = 0; + y = this.maxY + 50; + } + } + canvas.graph.add(node); + node.pos = [x, y]; + canvas.selectNode(node); + app.graph.setDirtyCanvas(true, true); + await wait(); + this.lastNode = node; + this.maxY = Math.max(this.maxY, y + this.lastNode.size[1]); + return (this.lastNode = node); + } + + async clear() { + app.clean(); + app.graph.clear(); + const nodeConfig = await this.addNode(NodeTypesString.KSAMPLER_CONFIG); + const displayAny = await this.addNode(NodeTypesString.DISPLAY_ANY); + nodeConfig.widgets[0]!.value = Math.round(Math.random() * 100); + nodeConfig.connect(0, displayAny, 0); + await this.queuePrompt(); + app.clean(); + app.graph.clear(); + this.lastNode = null; + this.maxY = 0; + await wait(); + } + + async queuePrompt() { + await app.queuePrompt(); + await wait(150); + } +} diff --git a/rgthree-comfy/src_web/comfyui/testing/runner.ts b/rgthree-comfy/src_web/comfyui/testing/runner.ts new file mode 100644 index 0000000000000000000000000000000000000000..a5be07da1fb24a4aeac279b7e2727990ce8923e1 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/testing/runner.ts @@ -0,0 +1,133 @@ +/** + * @fileoverview A set of methods that mimic a bit of the Jasmine testing library, but simpler and + * more succinct for manipulating a comfy integration test. 
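+ *
+ * A minimal sketch of the intended flow (the registration pattern mirrors the tests in this repo;
+ * names are illustrative):
+ *   (window as any).rgthree_tests = (window as any).rgthree_tests || {};
+ *   (window as any).rgthree_tests.example = describe("Example", async () => {
+ *     await should("add numbers", async () => {
+ *       expect(1 + 1).toBe("sum", 2);
+ *     });
+ *   });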
+ */
+import { wait } from "rgthree/common/shared_utils.js";
+
+type TestContext = {
+  label?: string;
+  beforeEach?: Function[];
+};
+
+let contexts: TestContext[] = [];
+
+export function describe(label: string, fn: Function) {
+  return async () => {
+    await describeRun(label, fn);
+  };
+}
+
+export async function describeRun(label: string, fn: Function) {
+  await wait();
+  contexts.push({ label });
+  console.group(`[Start] ${contexts[contexts.length - 1]!.label}`);
+  await fn();
+  contexts.pop();
+  console.groupEnd();
+}
+
+export async function should(declaration: string, fn: Function) {
+  if (!contexts[contexts.length - 1]) {
+    throw Error("Called should outside of a describe.");
+  }
+  console.group(`...should ${declaration}`);
+  try {
+    for (const context of contexts) {
+      for (const beforeEachFn of context?.beforeEach || []) {
+        await beforeEachFn();
+      }
+    }
+    await fn();
+  } catch (e: any) {
+    fail(e);
+  }
+  console.groupEnd();
+}
+
+export async function beforeEach(fn: Function) {
+  if (!contexts[contexts.length - 1]) {
+    throw Error("Called beforeEach outside of a describe.");
+  }
+  const last = contexts[contexts.length - 1]!;
+  last.beforeEach = last?.beforeEach || [];
+  last.beforeEach.push(fn);
+}
+
+export function fail(e: Error) {
+  log(`X Failure: ${e}`, "color:#600; background:#fdd; padding: 2px 6px;");
+}
+
+function log(msg: string, styles: string) {
+  if (styles) {
+    console.log(`%c ${msg}`, styles);
+  } else {
+    console.log(msg);
+  }
+}
+
+class Expectation {
+  private propertyLabel: string | null = "";
+  private expectedLabel: string | null = "";
+  private expectedFn!: (v: any) => boolean;
+  private value: any;
+
+  constructor(value: any) {
+    this.value = value;
+  }
+
+  toBe(labelOrExpected: any, maybeExpected?: any) {
+    const expected = maybeExpected !== undefined ? maybeExpected : labelOrExpected;
+    this.propertyLabel = maybeExpected !== undefined ? labelOrExpected : null;
+    this.expectedLabel = JSON.stringify(expected);
+    this.expectedFn = (v) => v == expected;
+    return this.toBeEval();
+  }
+  toBeUndefined(propertyLabel: string) {
+    this.expectedFn = (v) => v === undefined;
+    this.propertyLabel = propertyLabel || "";
+    this.expectedLabel = "undefined";
+    return this.toBeEval(true);
+  }
+  toBeNullOrUndefined(propertyLabel: string) {
+    this.expectedFn = (v) => v == null;
+    this.propertyLabel = propertyLabel || "";
+    this.expectedLabel = "null or undefined";
+    return this.toBeEval(true);
+  }
+  toBeTruthy(propertyLabel: string) {
+    // Must coerce to boolean so falsy values fail the check.
+    this.expectedFn = (v) => !!v;
+    this.propertyLabel = propertyLabel || "";
+    this.expectedLabel = "truthy";
+    return this.toBeEval(false);
+  }
+  toBeANumber(propertyLabel: string) {
+    this.expectedFn = (v) => typeof v === "number";
+    this.propertyLabel = propertyLabel || "";
+    this.expectedLabel = "a number";
+    return this.toBeEval();
+  }
+  toBeEval(strict = false) {
+    let evaluation = this.expectedFn(this.value);
+    let msg = `Expected ${this.propertyLabel ? this.propertyLabel + " to be " : ""}${
+      this.expectedLabel
+    }`;
+    msg += evaluation ? "."
: `, but was ${JSON.stringify(this.value)}`; + this.log(evaluation, msg); + return evaluation; + } + log(value: boolean, msg: string) { + if (value) { + log(`🗸 ${msg}`, "color:#060; background:#cec; padding: 2px 6px;"); + } else { + log(`X ${msg}`, "color:#600; background:#fdd; padding: 2px 6px;"); + } + } +} + +export function expect(value: any, msg?: string) { + const expectation = new Expectation(value); + if (msg) { + expectation.log(value, msg); + } + return expectation; +} diff --git a/rgthree-comfy/src_web/comfyui/tests/context_dynamic_tests.ts b/rgthree-comfy/src_web/comfyui/tests/context_dynamic_tests.ts new file mode 100644 index 0000000000000000000000000000000000000000..0be20a2715bd2c16e4e2f16310a3cf95b36e606a --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/tests/context_dynamic_tests.ts @@ -0,0 +1,191 @@ +import type { + LiteGraph as TLiteGraph, + LGraphCanvas as TLGraphCanvas, + LGraph as TLGraph, + LGraphNode as TLGraphNode, + Vector2, + LGraphNode, +} from "typings/litegraph.js"; +import {rgthree} from "../rgthree.js"; +import {NodeTypesString} from "../constants.js"; +import {wait} from "rgthree/common/shared_utils.js"; +import {describe, should, beforeEach, expect, describeRun} from "../testing/runner.js"; +import {ComfyUITestEnvironment} from "../testing/comfyui_env.js"; + +declare const LiteGraph: typeof TLiteGraph; + +const env = new ComfyUITestEnvironment(); + +function verifyInputAndOutputName( + node: LGraphNode, + index: number, + inputName: string | null, + isLinked?: boolean, +) { + if (inputName != null) { + expect(node.inputs[index]!.name).toBe(`input ${index} name`, inputName); + } + if (isLinked) { + expect(node.inputs[index]!.link).toBeANumber(`input ${index} connection`); + } else if (isLinked === false) { + expect(node.inputs[index]!.link).toBeNullOrUndefined(`input ${index} connection`); + } + if (inputName != null) { + if (inputName === "+") { + expect(node.outputs[index]).toBeUndefined(`output ${index}`); + } else { + let outputName = + inputName === "base_ctx" ? 
"CONTEXT" : inputName.replace(/^\+\s/, "").toUpperCase(); + expect(node.outputs[index]!.name).toBe(`output ${index} name`, outputName); + } + } +} + +function vertifyInputsStructure(node: LGraphNode, expectedLength: number) { + expect(node.inputs.length).toBe("inputs length", expectedLength); + expect(node.outputs.length).toBe("outputs length", expectedLength - 1); + verifyInputAndOutputName(node, expectedLength - 1, "+", false); +} + +(window as any).rgthree_tests = (window as any).rgthree_tests || {}; +(window as any).rgthree_tests.test_dynamic_context = describe("ContextDynamicTest", async () => { + let nodeConfig!: TLGraphNode; + let nodeCtx!: TLGraphNode; + + let lastNode: LGraphNode | null = null; + + await beforeEach(async () => { + await env.clear(); + lastNode = nodeConfig = await env.addNode(NodeTypesString.KSAMPLER_CONFIG); + lastNode = nodeCtx = await env.addNode(NodeTypesString.DYNAMIC_CONTEXT); + nodeConfig.connect(0, nodeCtx, 1); // steps + nodeConfig.connect(2, nodeCtx, 2); // cfg + nodeConfig.connect(4, nodeCtx, 3); // scheduler + nodeConfig.connect(0, nodeCtx, 4); // This is the step.1 + nodeConfig.connect(0, nodeCtx, 5); // This is the step.2 + nodeCtx.disconnectInput(2); + nodeCtx.disconnectInput(5); + nodeConfig.connect(0, nodeCtx, 6); // This is the step.3 + nodeCtx.disconnectInput(6); + await wait(); + }); + + await should("add correct inputs", async () => { + vertifyInputsStructure(nodeCtx, 8); + let i = 0; + verifyInputAndOutputName(nodeCtx, i++, "base_ctx", false); + verifyInputAndOutputName(nodeCtx, i++, "+ steps", true); + verifyInputAndOutputName(nodeCtx, i++, "+ cfg", false); + verifyInputAndOutputName(nodeCtx, i++, "+ scheduler", true); + verifyInputAndOutputName(nodeCtx, i++, "+ steps.1", true); + verifyInputAndOutputName(nodeCtx, i++, "+ steps.2", false); + verifyInputAndOutputName(nodeCtx, i++, "+ steps.3", false); + }); + + await should("add evaluate correct outputs", async () => { + const displayAny1 = await env.addNode(NodeTypesString.DISPLAY_ANY, {placement: "right"}); + const displayAny2 = await env.addNode(NodeTypesString.DISPLAY_ANY, {placement: "under"}); + const displayAny3 = await env.addNode(NodeTypesString.DISPLAY_ANY, {placement: "under"}); + const displayAny4 = await env.addNode(NodeTypesString.DISPLAY_ANY, {placement: "under"}); + + nodeCtx.connect(1, displayAny1, 0); // steps + nodeCtx.connect(3, displayAny2, 0); // scheduler + nodeCtx.connect(4, displayAny3, 0); // steps.1 + nodeCtx.connect(6, displayAny4, 0); // steps.3 (unlinked) + + await env.queuePrompt(); + + expect(displayAny1.widgets![0]!.value).toBe("output 1", 30); + expect(displayAny2.widgets![0]!.value).toBe("output 3", '"normal"'); + expect(displayAny3.widgets![0]!.value).toBe("output 4", 30); + expect(displayAny4.widgets![0]!.value).toBe("output 6", "None"); + }); + + await describeRun("Nested", async () => { + let nodeConfig2!: TLGraphNode; + let nodeCtx2!: TLGraphNode; + + await beforeEach(async () => { + nodeConfig2 = await env.addNode(NodeTypesString.KSAMPLER_CONFIG, {placement: "start"}); + nodeConfig2.widgets[0]!.value = 111; + nodeConfig2.widgets[2]!.value = 11.1; + nodeCtx2 = await env.addNode(NodeTypesString.DYNAMIC_CONTEXT, {placement: "right"}); + nodeConfig2.connect(0, nodeCtx2, 1); // steps + nodeConfig2.connect(2, nodeCtx2, 2); // cfg + nodeConfig2.connect(3, nodeCtx2, 3); // sampler + nodeConfig2.connect(2, nodeCtx2, 4); // This is the cfg.1 + nodeConfig2.connect(0, nodeCtx2, 5); // This is the steps.1 + nodeCtx2.disconnectInput(2); + 
nodeCtx2.disconnectInput(5); + nodeConfig2.connect(2, nodeCtx2, 6); // This is the cfg.2 + nodeCtx2.disconnectInput(6); + + await wait(); + }); + + await should("disallow context node to be connected to non-first spot.", async () => { + // Connect to first node. + let expectedInputs = 8; + + nodeCtx2.connect(0, nodeCtx, expectedInputs - 1); + console.log(nodeCtx.inputs); + + vertifyInputsStructure(nodeCtx, expectedInputs); + verifyInputAndOutputName(nodeCtx, 0, "base_ctx", false); + verifyInputAndOutputName(nodeCtx, nodeCtx.inputs.length - 1, null, false); + + nodeCtx2.connect(0, nodeCtx, 0); + expectedInputs = 14; + vertifyInputsStructure(nodeCtx, expectedInputs); + verifyInputAndOutputName(nodeCtx, 0, "base_ctx", true); + verifyInputAndOutputName(nodeCtx, expectedInputs - 1, null, false); + }); + + await should("add inputs from connected above owned.", async () => { + // Connect to first node. + nodeCtx2.connect(0, nodeCtx, 0); + + let expectedInputs = 14; + vertifyInputsStructure(nodeCtx, expectedInputs); + let i = 0; + verifyInputAndOutputName(nodeCtx, i++, "base_ctx", true); + verifyInputAndOutputName(nodeCtx, i++, "steps", false); + verifyInputAndOutputName(nodeCtx, i++, "cfg", false); + verifyInputAndOutputName(nodeCtx, i++, "sampler", false); + verifyInputAndOutputName(nodeCtx, i++, "cfg.1", false); + verifyInputAndOutputName(nodeCtx, i++, "steps.1", false); + verifyInputAndOutputName(nodeCtx, i++, "cfg.2", false); + verifyInputAndOutputName(nodeCtx, i++, "+ steps.2", true); + verifyInputAndOutputName(nodeCtx, i++, "+ cfg.3", false); + verifyInputAndOutputName(nodeCtx, i++, "+ scheduler", true); + verifyInputAndOutputName(nodeCtx, i++, "+ steps.3", true); + verifyInputAndOutputName(nodeCtx, i++, "+ steps.4", false); + verifyInputAndOutputName(nodeCtx, i++, "+ steps.5", false); + verifyInputAndOutputName(nodeCtx, i++, "+", false); + }); + + await should("add then remove inputs when disconnected.", async () => { + // Connect to first node. 
+ nodeCtx2.connect(0, nodeCtx, 0); + + let expectedInputs = 14; + expect(nodeCtx.inputs.length).toBe("inputs length", expectedInputs); + expect(nodeCtx.outputs.length).toBe("outputs length", expectedInputs - 1); + + nodeCtx.disconnectInput(0); + + expectedInputs = 8; + expect(nodeCtx.inputs.length).toBe("inputs length", expectedInputs); + expect(nodeCtx.outputs.length).toBe("outputs length", expectedInputs - 1); + let i = 0; + verifyInputAndOutputName(nodeCtx, i++, "base_ctx", false); + verifyInputAndOutputName(nodeCtx, i++, "+ steps", true); + verifyInputAndOutputName(nodeCtx, i++, "+ cfg", false); + verifyInputAndOutputName(nodeCtx, i++, "+ scheduler", true); + verifyInputAndOutputName(nodeCtx, i++, "+ steps.1", true); + verifyInputAndOutputName(nodeCtx, i++, "+ steps.2", false); + verifyInputAndOutputName(nodeCtx, i++, "+ steps.3", false); + verifyInputAndOutputName(nodeCtx, i++, "+", false); + }); + }); +}); diff --git a/rgthree-comfy/src_web/comfyui/utils.ts b/rgthree-comfy/src_web/comfyui/utils.ts new file mode 100644 index 0000000000000000000000000000000000000000..ec783095593178fd20e7aea96d8c54600c79b3a7 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/utils.ts @@ -0,0 +1,927 @@ +import type { ComfyApp, ComfyNodeConstructor, ComfyObjectInfo } from "typings/comfy.js"; +import type { + Vector2, + LGraphCanvas, + ContextMenuItem, + LLink, + LGraph, + IContextMenuOptions, + ContextMenu, + LGraphNode, + INodeSlot, + INodeInputSlot, + INodeOutputSlot, +} from "typings/litegraph.js"; +import type { Constructor } from "typings/index.js"; +import { app } from "scripts/app.js"; +import { api } from "scripts/api.js"; +import { Resolver, getResolver, wait } from "rgthree/common/shared_utils.js"; +import { RgthreeHelpDialog } from "rgthree/common/dialog.js"; + +/** + * Override the api.getNodeDefs call to add a hook for refreshing node defs. + * This is necessary for power prompt's custom combos. Since API implements + * add/removeEventListener already, this is rather trivial. 
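+ *
+ * E.g., a consumer can re-read fresh defs like so (a sketch; the handler body is up to the caller):
+ *   api.addEventListener("fresh-node-defs", (e) => {
+ *     const defs = (e as CustomEvent).detail; // same shape as the getNodeDefs() result
+ *   });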
+ */ +const oldApiGetNodeDefs = api.getNodeDefs; +api.getNodeDefs = async function () { + const defs = await oldApiGetNodeDefs.call(api); + this.dispatchEvent(new CustomEvent("fresh-node-defs", { detail: defs })); + return defs; +}; + +export enum IoDirection { + INPUT, + OUTPUT, +} + +const PADDING = 0; + +type LiteGraphDir = + | typeof LiteGraph.LEFT + | typeof LiteGraph.RIGHT + | typeof LiteGraph.UP + | typeof LiteGraph.DOWN; +export const LAYOUT_LABEL_TO_DATA: { [label: string]: [LiteGraphDir, Vector2, Vector2] } = { + Left: [LiteGraph.LEFT, [0, 0.5], [PADDING, 0]], + Right: [LiteGraph.RIGHT, [1, 0.5], [-PADDING, 0]], + Top: [LiteGraph.UP, [0.5, 0], [0, PADDING]], + Bottom: [LiteGraph.DOWN, [0.5, 1], [0, -PADDING]], +}; +export const LAYOUT_LABEL_OPPOSITES: { [label: string]: string } = { + Left: "Right", + Right: "Left", + Top: "Bottom", + Bottom: "Top", +}; +export const LAYOUT_CLOCKWISE = ["Top", "Right", "Bottom", "Left"]; + +interface MenuConfig { + name: string | ((node: LGraphNode) => string); + property?: string; + prepareValue?: (value: string, node: LGraphNode) => any; + callback?: (node: LGraphNode, value?: string) => void; + subMenuOptions?: (string | null)[] | ((node: LGraphNode) => (string | null)[]); +} + +export function addMenuItem( + node: Constructor, + _app: ComfyApp, + config: MenuConfig, + after = "Shape", +) { + const oldGetExtraMenuOptions = node.prototype.getExtraMenuOptions; + node.prototype.getExtraMenuOptions = function ( + canvas: LGraphCanvas, + menuOptions: ContextMenuItem[], + ) { + oldGetExtraMenuOptions && oldGetExtraMenuOptions.apply(this, [canvas, menuOptions]); + addMenuItemOnExtraMenuOptions(this, config, menuOptions, after); + }; +} + +/** + * Waits for the canvas to be available on app using a single promise. + */ +let canvasResolver: Resolver | null = null; +export function waitForCanvas() { + if (canvasResolver === null) { + canvasResolver = getResolver(); + function _waitForCanvas() { + if (!canvasResolver!.completed) { + if (app?.canvas) { + canvasResolver!.resolve(app.canvas); + } else { + requestAnimationFrame(_waitForCanvas); + } + } + } + _waitForCanvas(); + } + return canvasResolver.promise; +} + +/** + * Waits for the graph to be available on app using a single promise. + */ +let graphResolver: Resolver | null = null; +export function waitForGraph() { + if (graphResolver === null) { + graphResolver = getResolver(); + function _wait() { + if (!graphResolver!.completed) { + if (app?.graph) { + graphResolver!.resolve(app.graph); + } else { + requestAnimationFrame(_wait); + } + } + } + _wait(); + } + return graphResolver.promise; +} + +export function addMenuItemOnExtraMenuOptions( + node: LGraphNode, + config: MenuConfig, + menuOptions: ContextMenuItem[], + after = "Shape", +) { + let idx = menuOptions + .slice() + .reverse() + .findIndex((option) => (option as any)?.isRgthree); + if (idx == -1) { + idx = menuOptions.findIndex((option) => option?.content?.includes(after)) + 1; + if (!idx) { + idx = menuOptions.length - 1; + } + // Add a separator, and move to the next one. + menuOptions.splice(idx, 0, null); + idx++; + } else { + idx = menuOptions.length - idx; + } + + const subMenuOptions = + typeof config.subMenuOptions === "function" + ? config.subMenuOptions(node) + : config.subMenuOptions; + + menuOptions.splice(idx, 0, { + content: typeof config.name == "function" ? config.name(node) : config.name, + has_submenu: !!subMenuOptions?.length, + isRgthree: true, // Mark it, so we can find it. 
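As a hedged sketch of how the helpers above might be wired together (the node class `MyNode` and the property name are hypothetical, not part of this file): `waitForGraph()` resolves once `app.graph` exists, and `addMenuItem` appends a right-click entry that, with no `prepareValue`, simply toggles the named boolean property.

// Hypothetical usage; MyNode is an illustrative node constructor, not defined here.
waitForGraph().then(() => {
  addMenuItem(MyNode, app, {
    name: "Toggle My Flag",
    property: "my_flag", // With no prepareValue, the menu callback flips this boolean property.
    callback: (node) => node.setDirtyCanvas(true, true),
  });
});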
+ callback: ( + value: ContextMenuItem, + _options: IContextMenuOptions, + event: MouseEvent, + parentMenu: ContextMenu | undefined, + _node: LGraphNode, + ) => { + if (!!subMenuOptions?.length) { + new LiteGraph.ContextMenu( + subMenuOptions.map((option) => (option ? { content: option } : null)), + { + event, + parentMenu, + callback: ( + subValue: ContextMenuItem, + _options: IContextMenuOptions, + _event: MouseEvent, + _parentMenu: ContextMenu | undefined, + _node: LGraphNode, + ) => { + if (config.property) { + node.properties = node.properties || {}; + node.properties[config.property] = config.prepareValue + ? config.prepareValue(subValue!.content || '', node) + : subValue!.content || ''; + } + config.callback && config.callback(node, subValue?.content); + }, + }, + ); + return; + } + if (config.property) { + node.properties = node.properties || {}; + node.properties[config.property] = config.prepareValue + ? config.prepareValue(node.properties[config.property], node) + : !node.properties[config.property]; + } + config.callback && config.callback(node, value?.content); + }, + } as ContextMenuItem); +} + +export function addConnectionLayoutSupport( + node: Constructor, + app: ComfyApp, + options = [ + ["Left", "Right"], + ["Right", "Left"], + ], + callback?: (node: LGraphNode) => void, +) { + addMenuItem(node, app, { + name: "Connections Layout", + property: "connections_layout", + subMenuOptions: options.map((option) => option[0] + (option[1] ? " -> " + option[1] : "")), + prepareValue: (value, node) => { + const values = value.split(" -> "); + if (!values[1] && !node.outputs?.length) { + values[1] = LAYOUT_LABEL_OPPOSITES[values[0]!]!; + } + if (!LAYOUT_LABEL_TO_DATA[values[0]!] || !LAYOUT_LABEL_TO_DATA[values[1]!]) { + throw new Error(`New Layout invalid: [${values[0]}, ${values[1]}]`); + } + return values; + }, + callback: (node) => { + callback && callback(node); + app.graph.setDirtyCanvas(true, true); + }, + }); + + // const oldGetConnectionPos = node.prototype.getConnectionPos; + node.prototype.getConnectionPos = function (isInput: boolean, slotNumber: number, out: Vector2) { + // Purposefully do not need to call the old one. + // oldGetConnectionPos && oldGetConnectionPos.apply(this, [isInput, slotNumber, out]); + return getConnectionPosForLayout(this, isInput, slotNumber, out); + }; +} + +export function setConnectionsLayout(node: LGraphNode, newLayout: [string, string]) { + newLayout = newLayout || (node as any).defaultConnectionsLayout || ["Left", "Right"]; + // If we didn't supply an output layout, and there's no outputs, then just choose the opposite of the + // input as a safety. + if (!newLayout[1] && !node.outputs?.length) { + newLayout[1] = LAYOUT_LABEL_OPPOSITES[newLayout[0]!]!; + } + if (!LAYOUT_LABEL_TO_DATA[newLayout[0]] || !LAYOUT_LABEL_TO_DATA[newLayout[1]]) { + throw new Error(`New Layout invalid: [${newLayout[0]}, ${newLayout[1]}]`); + } + node.properties = node.properties || {}; + node.properties["connections_layout"] = newLayout; +} + +/** Allows collapsing of connections into one. Pretty unusable, unless you're the muter. */ +export function setConnectionsCollapse( + node: LGraphNode, + collapseConnections: boolean | null = null, +) { + node.properties = node.properties || {}; + collapseConnections = + collapseConnections !== null ? 
collapseConnections : !node.properties["collapse_connections"]; + node.properties["collapse_connections"] = collapseConnections; +} + +export function getConnectionPosForLayout( + node: LGraphNode, + isInput: boolean, + slotNumber: number, + out: Vector2, +) { + out = out || new Float32Array(2); + node.properties = node.properties || {}; + const layout = node.properties["connections_layout"] || + (node as any).defaultConnectionsLayout || ["Left", "Right"]; + const collapseConnections = node.properties["collapse_connections"] || false; + const offset = (node.constructor as any).layout_slot_offset ?? LiteGraph.NODE_SLOT_HEIGHT * 0.5; + let side = isInput ? layout[0] : layout[1]; + const otherSide = isInput ? layout[1] : layout[0]; + let data = LAYOUT_LABEL_TO_DATA[side]!; // || LAYOUT_LABEL_TO_DATA[isInput ? 'Left' : 'Right']; + const slotList = node[isInput ? "inputs" : "outputs"]; + const cxn = slotList[slotNumber]; + if (!cxn) { + console.log("No connection found.. weird", isInput, slotNumber); + return out; + } + // Experimental; doesn't work without node.clip_area set (so it won't draw outside), + // but litegraph.core inexplicably clips the title off which we want... so, no go. + // if (cxn.hidden) { + // out[0] = node.pos[0] - 100000 + // out[1] = node.pos[1] - 100000 + // return out + // } + if (cxn.disabled) { + // Let's store the original colors if have them and haven't yet overridden + if (cxn.color_on !== "#666665") { + (cxn as any)._color_on_org = (cxn as any)._color_on_org || cxn.color_on; + (cxn as any)._color_off_org = (cxn as any)._color_off_org || cxn.color_off; + } + cxn.color_on = "#666665"; + cxn.color_off = "#666665"; + } else if (cxn.color_on === "#666665") { + cxn.color_on = (cxn as any)._color_on_org || undefined; + cxn.color_off = (cxn as any)._color_off_org || undefined; + } + const displaySlot = collapseConnections + ? 0 + : slotNumber - + slotList.reduce((count, ioput, index) => { + count += index < slotNumber && ioput.hidden ? 1 : 0; + return count; + }, 0); + // Set the direction first. This is how the connection line will be drawn. + cxn.dir = data[0]; + + // If we are only 10px tall or wide, then look at connections_dir for the direction. + if ((node.size[0] == 10 || node.size[1] == 10) && node.properties["connections_dir"]) { + cxn.dir = node.properties["connections_dir"][isInput ? 0 : 1]!; + } + + if (side === "Left") { + if (node.flags.collapsed) { + var w = (node as any)._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH; + out[0] = node.pos[0]; + out[1] = node.pos[1] - LiteGraph.NODE_TITLE_HEIGHT * 0.5; + } else { + // If we're an output, then the litegraph.core hates us; we need to blank out the name + // because it's not flexible enough to put the text on the inside. + toggleConnectionLabel(cxn, !isInput || collapseConnections || !!(node as any).hideSlotLabels); + out[0] = node.pos[0] + offset; + if ((node.constructor as any)?.type.includes("Reroute")) { + out[1] = node.pos[1] + node.size[1] * 0.5; + } else { + out[1] = + node.pos[1] + + (displaySlot + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + + ((node.constructor as any).slot_start_y || 0); + } + } + } else if (side === "Right") { + if (node.flags.collapsed) { + var w = (node as any)._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH; + out[0] = node.pos[0] + w; + out[1] = node.pos[1] - LiteGraph.NODE_TITLE_HEIGHT * 0.5; + } else { + // If we're an input, then the litegraph.core hates us; we need to blank out the name + // because it's not flexible enough to put the text on the inside. 
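For illustration only (the node class and instance names are hypothetical), this is roughly how the layout helpers above are meant to be used: `addConnectionLayoutSupport` installs the "Connections Layout" submenu and overrides `getConnectionPos`, while `setConnectionsLayout` writes the `connections_layout` property directly.

// Hypothetical: let a reroute-like node class lay out connections Left -> Right or Top -> Bottom.
addConnectionLayoutSupport(MyRerouteLikeNode, app, [
  ["Left", "Right"],
  ["Top", "Bottom"],
]);

// Or set a layout on a specific instance programmatically.
setConnectionsLayout(someNodeInstance, ["Top", "Bottom"]);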
+ toggleConnectionLabel(cxn, isInput || collapseConnections || !!(node as any).hideSlotLabels); + out[0] = node.pos[0] + node.size[0] + 1 - offset; + if ((node.constructor as any)?.type.includes("Reroute")) { + out[1] = node.pos[1] + node.size[1] * 0.5; + } else { + out[1] = + node.pos[1] + + (displaySlot + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + + ((node.constructor as any).slot_start_y || 0); + } + } + + // Right now, only reroute uses top/bottom, so this may not work for other nodes + // (like, applying to nodes with titles, collapsed, multiple inputs/outputs, etc). + } else if (side === "Top") { + if (!(cxn as any).has_old_label) { + (cxn as any).has_old_label = true; + (cxn as any).old_label = cxn.label; + cxn.label = " "; + } + out[0] = node.pos[0] + node.size[0] * 0.5; + out[1] = node.pos[1] + offset; + } else if (side === "Bottom") { + if (!(cxn as any).has_old_label) { + (cxn as any).has_old_label = true; + (cxn as any).old_label = cxn.label; + cxn.label = " "; + } + out[0] = node.pos[0] + node.size[0] * 0.5; + out[1] = node.pos[1] + node.size[1] - offset; + } + return out; +} + +function toggleConnectionLabel(cxn: any, hide = true) { + if (hide) { + if (!(cxn as any).has_old_label) { + (cxn as any).has_old_label = true; + (cxn as any).old_label = cxn.label; + } + cxn.label = " "; + } else if (!hide && (cxn as any).has_old_label) { + (cxn as any).has_old_label = false; + cxn.label = (cxn as any).old_label; + (cxn as any).old_label = undefined; + } + return cxn; +} + +export function addHelpMenuItem(node: LGraphNode, content: string, menuOptions: ContextMenuItem[]) { + addMenuItemOnExtraMenuOptions( + node, + { + name: "🛟 Node Help", + callback: (node) => { + if ((node as any).showHelp) { + (node as any).showHelp(); + } else { + new RgthreeHelpDialog(node, content).show(); + } + }, + }, + menuOptions, + "Properties Panel", + ); +} + +export enum PassThroughFollowing { + ALL, + NONE, + REROUTE_ONLY, +} + +/** + * Determines if, when doing a chain lookup for connected nodes, we want to pass through this node, + * like reroutes, etc. + */ +export function shouldPassThrough( + node?: LGraphNode | null, + passThroughFollowing = PassThroughFollowing.ALL, +) { + const type = (node?.constructor as typeof LGraphNode)?.type; + if (!type || passThroughFollowing === PassThroughFollowing.NONE) { + return false; + } + if (passThroughFollowing === PassThroughFollowing.REROUTE_ONLY) { + return type.includes("Reroute"); + } + return ( + type.includes("Reroute") || type.includes("Node Combiner") || type.includes("Node Collector") + ); +} + + +function filterOutPassthroughNodes( + infos: ConnectedNodeInfo[], + passThroughFollowing = PassThroughFollowing.ALL, +) { + return infos.filter((i) => !shouldPassThrough(i.node, passThroughFollowing)); +} + +/** + * Looks through the immediate chain of a node to collect all connected nodes, passing through nodes + * like reroute, etc. 
Will also disconnect duplicate nodes from a provided node + */ +export function getConnectedInputNodes( + startNode: LGraphNode, + currentNode?: LGraphNode, + slot?: number, + passThroughFollowing = PassThroughFollowing.ALL, +): LGraphNode[] { + return getConnectedNodesInfo( + startNode, + IoDirection.INPUT, + currentNode, + slot, + passThroughFollowing, + ).map((n) => n.node); +} +export function getConnectedInputInfosAndFilterPassThroughs( + startNode: LGraphNode, + currentNode?: LGraphNode, + slot?: number, + passThroughFollowing = PassThroughFollowing.ALL) { + return filterOutPassthroughNodes( + getConnectedNodesInfo(startNode, IoDirection.INPUT, currentNode, slot, passThroughFollowing), + passThroughFollowing); +} +export function getConnectedInputNodesAndFilterPassThroughs( + startNode: LGraphNode, + currentNode?: LGraphNode, + slot?: number, + passThroughFollowing = PassThroughFollowing.ALL, +): LGraphNode[] { + return getConnectedInputInfosAndFilterPassThroughs(startNode, currentNode, slot, passThroughFollowing).map(n => n.node); +} + +export function getConnectedOutputNodes( + startNode: LGraphNode, + currentNode?: LGraphNode, + slot?: number, + passThroughFollowing = PassThroughFollowing.ALL, +): LGraphNode[] { + return getConnectedNodesInfo( + startNode, + IoDirection.OUTPUT, + currentNode, + slot, + passThroughFollowing, + ).map((n) => n.node); +} + +export function getConnectedOutputNodesAndFilterPassThroughs( + startNode: LGraphNode, + currentNode?: LGraphNode, + slot?: number, + passThroughFollowing = PassThroughFollowing.ALL, +): LGraphNode[] { + return filterOutPassthroughNodes( + getConnectedNodesInfo(startNode, IoDirection.OUTPUT, currentNode, slot, passThroughFollowing), + passThroughFollowing, + ).map(n => n.node); +} + +export type ConnectedNodeInfo = { + node: LGraphNode; + travelFromSlot: number; + travelToSlot: number; + originTravelFromSlot: number; +}; + +export function getConnectedNodesInfo( + startNode: LGraphNode, + dir = IoDirection.INPUT, + currentNode?: LGraphNode, + slot?: number, + passThroughFollowing = PassThroughFollowing.ALL, + originTravelFromSlot?: number, +): ConnectedNodeInfo[] { + currentNode = currentNode || startNode; + let rootNodes: ConnectedNodeInfo[] = []; + if (startNode === currentNode || shouldPassThrough(currentNode, passThroughFollowing)) { + let linkIds: Array; + + slot = slot != null && slot > -1 ? slot : undefined; + if (dir == IoDirection.OUTPUT) { + if (slot != null) { + linkIds = [...(currentNode.outputs?.[slot]?.links || [])]; + } else { + linkIds = currentNode.outputs?.flatMap((i) => i.links) || []; + } + } else { + if (slot != null) { + linkIds = [currentNode.inputs?.[slot]?.link]; + } else { + linkIds = currentNode.inputs?.map((i) => i.link) || []; + } + } + let graph = app.graph as LGraph; + for (const linkId of linkIds) { + let link: LLink | null = null; + if (typeof linkId == "number") { + link = graph.links[linkId] as LLink; + } + if (!link) { + continue; + } + const travelFromSlot = dir == IoDirection.OUTPUT ? link.origin_slot : link.target_slot; + const connectedId = dir == IoDirection.OUTPUT ? link.target_id : link.origin_id; + const travelToSlot = dir == IoDirection.OUTPUT ? link.target_slot : link.origin_slot; + originTravelFromSlot = originTravelFromSlot != null ? originTravelFromSlot : travelFromSlot; + const originNode: LGraphNode = graph.getNodeById(connectedId)!; + if (!link) { + console.error("No connected node found... 
weird"); + continue; + } + if (rootNodes.some((n) => n.node == originNode)) { + console.log( + `${startNode.title} (${startNode.id}) seems to have two links to ${originNode.title} (${ + originNode.id + }). One may be stale: ${linkIds.join(", ")}`, + ); + } else { + // Add the node and, if it's a pass through, let's collect all its nodes as well. + rootNodes.push({ node: originNode, travelFromSlot, travelToSlot, originTravelFromSlot }); + if (shouldPassThrough(originNode, passThroughFollowing)) { + for (const foundNode of getConnectedNodesInfo( + startNode, + dir, + originNode, + undefined, + undefined, + originTravelFromSlot, + )) { + if (!rootNodes.map((n) => n.node).includes(foundNode.node)) { + rootNodes.push(foundNode); + } + } + } + } + } + } + return rootNodes; +} + +export type ConnectionType = { + type: string | string[]; + name: string | undefined; + label: string | undefined; +}; + +/** + * Follows a connection until we find a type associated with a slot. + * `skipSelf` skips the current slot, useful when we may have a dynamic slot that we want to start + * from, but find a type _after_ it (in case it needs to change). + */ +export function followConnectionUntilType( + node: LGraphNode, + dir: IoDirection, + slotNum?: number, + skipSelf = false, +): ConnectionType | null { + const slots = dir === IoDirection.OUTPUT ? node.outputs : node.inputs; + if (!slots || !slots.length) { + return null; + } + let type: ConnectionType | null = null; + if (slotNum) { + if (!slots[slotNum]) { + return null; + } + type = getTypeFromSlot(slots[slotNum], dir, skipSelf); + } else { + for (const slot of slots) { + type = getTypeFromSlot(slot, dir, skipSelf); + if (type) { + break; + } + } + } + return type; +} + +/** + * Gets the type from a slot. If the type is '*' then it will follow the node to find the next slot. + */ +function getTypeFromSlot( + slot: INodeInputSlot | INodeOutputSlot | undefined, + dir: IoDirection, + skipSelf = false, +): ConnectionType | null { + let graph = app.graph as LGraph; + let type = slot?.type; + if (!skipSelf && type != null && type != "*") { + return { type: type as string, label: slot?.label, name: slot?.name }; + } + const links = getSlotLinks(slot); + for (const link of links) { + const connectedId = dir == IoDirection.OUTPUT ? link.link.target_id : link.link.origin_id; + const connectedSlotNum = + dir == IoDirection.OUTPUT ? link.link.target_slot : link.link.origin_slot; + const connectedNode: LGraphNode = graph.getNodeById(connectedId)!; + // Reversed since if we're traveling down the output we want the connected node's input, etc. + const connectedSlots = + dir === IoDirection.OUTPUT ? connectedNode.inputs : connectedNode.outputs; + let connectedSlot = connectedSlots[connectedSlotNum]; + if (connectedSlot?.type != null && connectedSlot?.type != "*") { + return { + type: connectedSlot.type as string, + label: connectedSlot?.label, + name: connectedSlot?.name, + }; + } else if (connectedSlot?.type == "*") { + return followConnectionUntilType(connectedNode, dir); + } + } + return null; +} + +export async function replaceNode( + existingNode: LGraphNode, + typeOrNewNode: string | LGraphNode, + inputNameMap?: Map, +) { + const existingCtor = existingNode.constructor as typeof LGraphNode; + + const newNode = + typeof typeOrNewNode === "string" ? LiteGraph.createNode(typeOrNewNode) : typeOrNewNode; + // Port title (maybe) the position, size, and properties from the old node. 
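A small usage sketch of the traversal helpers above (`someNode` is just a stand-in for any `LGraphNode` on the graph): gather the real upstream nodes feeding slot 0, skipping reroutes and collectors, then walk the output side until a concrete (non-`*`) slot type is found.

// Illustrative only; someNode stands in for an existing LGraphNode instance.
const upstream = getConnectedInputNodesAndFilterPassThroughs(someNode, undefined, 0);
console.log(upstream.map((n) => `${n.title} (#${n.id})`).join(", "));

const resolved = followConnectionUntilType(someNode, IoDirection.OUTPUT);
console.log("resolved type:", resolved?.type, "from slot:", resolved?.name);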
+ if (existingNode.title != existingCtor.title) { + newNode.title = existingNode.title; + } + newNode.pos = [...existingNode.pos]; + newNode.properties = { ...existingNode.properties }; + const oldComputeSize = [...existingNode.computeSize()]; + // oldSize to use. If we match the smallest size (computeSize) then don't record and we'll use + // the smalles side after conversion. + const oldSize = [ + existingNode.size[0] === oldComputeSize[0] ? null : existingNode.size[0], + existingNode.size[1] === oldComputeSize[1] ? null : existingNode.size[1], + ]; + + let setSizeIters = 0; + const setSizeFn = () => { + // Size gets messed up when ComfyUI adds the text widget, so reset after a delay. + // Since we could be adding many more slots, let's take the larger of the two. + const newComputesize = newNode.computeSize(); + newNode.size[0] = Math.max(oldSize[0] || 0, newComputesize[0]); + newNode.size[1] = Math.max(oldSize[1] || 0, newComputesize[1]); + setSizeIters++; + if (setSizeIters > 10) { + requestAnimationFrame(setSizeFn); + } + }; + setSizeFn(); + + // We now collect the links data, inputs and outputs, of the old node since these will be + // lost when we remove it. + const links: { + node: LGraphNode; + slot: number | string; + targetNode: LGraphNode; + targetSlot: number | string; + }[] = []; + for (const [index, output] of existingNode.outputs.entries()) { + for (const linkId of output.links || []) { + const link: LLink = (app.graph as LGraph).links[linkId]!; + if (!link) continue; + const targetNode = app.graph.getNodeById(link.target_id)!; + links.push({ node: newNode, slot: output.name, targetNode, targetSlot: link.target_slot }); + } + } + for (const [index, input] of existingNode.inputs.entries()) { + const linkId = input.link; + if (linkId) { + const link: LLink = (app.graph as LGraph).links[linkId]!; + const originNode = app.graph.getNodeById(link.origin_id)!; + links.push({ + node: originNode, + slot: link.origin_slot, + targetNode: newNode, + targetSlot: inputNameMap?.has(input.name) + ? inputNameMap.get(input.name)! + : input.name || index, + }); + } + } + // Add the new node, remove the old node. + app.graph.add(newNode); + await wait(); + // Now go through and connect the other nodes up as they were. + for (const link of links) { + link.node.connect(link.slot, link.targetNode, link.targetSlot); + } + await wait(); + app.graph.remove(existingNode); + newNode.size = newNode.computeSize(); + newNode.setDirtyCanvas(true, true); + return newNode; +} + +export function getOriginNodeByLink(linkId?: number | null) { + let node: LGraphNode | null = null; + if (linkId != null) { + const link: LLink = app.graph.links[linkId]!; + node = (link != null && app.graph.getNodeById(link.origin_id)) || null; + } + return node; +} + +export function applyMixins(original: Constructor, constructors: any[]) { + constructors.forEach((baseCtor) => { + Object.getOwnPropertyNames(baseCtor.prototype).forEach((name) => { + Object.defineProperty( + original.prototype, + name, + Object.getOwnPropertyDescriptor(baseCtor.prototype, name) || Object.create(null), + ); + }); + }); +} + +/** + * Retruns a list of `{id: number, link: LLlink}` for a given input or output. + * + * Obviously, for an input, this will be a max of one. 
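`replaceNode` above preserves title, position, and properties, and reconnects links by slot name; as a hedged sketch (the node type string and the input rename are purely illustrative, and this runs inside an async context):

// Illustrative: swap an existing node for another registered type, remapping one renamed input.
const replacement = await replaceNode(
  oldNode, // hypothetical existing LGraphNode
  "MyReplacementNodeType", // hypothetical registered node type string
  new Map([["old_input_name", "new_input_name"]]),
);
console.log("replaced with", replacement.title);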
+ */ +export function getSlotLinks(inputOrOutput?: INodeInputSlot | INodeOutputSlot | null) { + const links: { id: number; link: LLink }[] = []; + if (!inputOrOutput) { + return links; + } + if ((inputOrOutput as INodeOutputSlot).links?.length) { + const output = inputOrOutput as INodeOutputSlot; + for (const linkId of output.links || []) { + const link: LLink = (app.graph as LGraph).links[linkId]!; + if (link) { + links.push({ id: linkId, link: link }); + } + } + } + if ((inputOrOutput as INodeInputSlot).link) { + const input = inputOrOutput as INodeInputSlot; + const link: LLink = (app.graph as LGraph).links[input.link!]!; + if (link) { + links.push({ id: input.link!, link: link }); + } + } + return links; +} + +/** + * Given a node, whether we're dealing with INPUTS or OUTPUTS, and the server data, re-arrange then + * slots to match the order. + */ +export async function matchLocalSlotsToServer( + node: LGraphNode, + direction: IoDirection, + serverNodeData: ComfyObjectInfo, +) { + const serverSlotNames = + direction == IoDirection.INPUT + ? Object.keys(serverNodeData.input?.optional || {}) + : serverNodeData.output_name; + const serverSlotTypes = + direction == IoDirection.INPUT + ? (Object.values(serverNodeData.input?.optional || {}).map((i) => i[0]) as string[]) + : serverNodeData.output; + const slots = direction == IoDirection.INPUT ? node.inputs : node.outputs; + + // Let's go through the node data names and make sure our current ones match, and update if not. + let firstIndex = slots.findIndex((o, i) => i !== serverSlotNames.indexOf(o.name)); + if (firstIndex > -1) { + // Have mismatches. First, let's go through and save all our links by name. + const links: { [key: string]: { id: number; link: LLink }[] } = {}; + slots.map((slot) => { + // There's a chance we have duplicate names on an upgrade, so we'll collect all links to one + // name so we don't ovewrite our list per name. + links[slot.name] = links[slot.name] || []; + links[slot.name]?.push(...getSlotLinks(slot)); + }); + + // Now, go through and rearrange outputs by splicing + for (const [index, serverSlotName] of serverSlotNames.entries()) { + const currentNodeSlot = slots.map((s) => s.name).indexOf(serverSlotName); + if (currentNodeSlot > -1) { + if (currentNodeSlot != index) { + const splicedItem = slots.splice(currentNodeSlot, 1)[0]!; + slots.splice(index, 0, splicedItem as any); + } + } else if (currentNodeSlot === -1) { + const splicedItem = { + name: serverSlotName, + type: serverSlotTypes![index], + links: [], + }; + slots.splice(index, 0, splicedItem as any); + } + } + + if (slots.length > serverSlotNames.length) { + for (let i = slots.length - 1; i > serverSlotNames.length - 1; i--) { + if (direction == IoDirection.INPUT) { + node.disconnectInput(i); + node.removeInput(i); + } else { + node.disconnectOutput(i); + node.removeOutput(i); + } + } + } + + // Now, go through the link data again and make sure the origin_slot is the correct slot. + for (const [name, slotLinks] of Object.entries(links)) { + let currentNodeSlot = slots.map((s) => s.name).indexOf(name); + if (currentNodeSlot > -1) { + for (const linkData of slotLinks) { + if (direction == IoDirection.INPUT) { + linkData.link.target_slot = currentNodeSlot; + } else { + linkData.link.origin_slot = currentNodeSlot; + // If our next node is a Reroute, then let's get it to update the type. + const nextNode = app.graph.getNodeById(linkData.link.target_id); + // (Check nextNode, as sometimes graphs seem to have very stale data and that node id + // doesn't exist). 
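For reference, a tiny sketch of `getSlotLinks` above, which normalizes an output's `links[]` and an input's single `link` into one `{id, link}` list (`someNode` is again a stand-in):

// Illustrative: inspect how many links hang off a node's first output and first input.
const outputLinks = getSlotLinks(someNode.outputs?.[0]);
const inputLinks = getSlotLinks(someNode.inputs?.[0]);
console.log(`output[0]: ${outputLinks.length} link(s), input[0]: ${inputLinks.length} link(s)`);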
+ if ( + nextNode && + (nextNode.constructor as ComfyNodeConstructor)?.type!.includes("Reroute") + ) { + (nextNode as any).stabilize && (nextNode as any).stabilize(); + } + } + } + } + } + } +} + +export function isValidConnection(ioA?: INodeSlot | null, ioB?: INodeSlot | null) { + if (!ioA || !ioB) { + return false; + } + const typeA = String(ioA.type); + const typeB = String(ioB.type); + // What does litegraph think, which includes looking at array values. + let isValid = LiteGraph.isValidConnection(typeA, typeB); + + // This is here to fix the churn happening in list types in comfyui itself.. + // https://github.com/comfyanonymous/ComfyUI/issues/1674 + if (!isValid) { + let areCombos = + (typeA.includes(",") && typeB === "COMBO") || (typeA === "COMBO" && typeB.includes(",")); + // We don't want to let any old combo connect to any old combo, so we'll look at the names too. + if (areCombos) { + // Some nodes use "_name" and some use "model" and "ckpt", so normalize + const nameA = ioA.name.toUpperCase().replace("_NAME", "").replace("CKPT", "MODEL"); + const nameB = ioB.name.toUpperCase().replace("_NAME", "").replace("CKPT", "MODEL"); + isValid = nameA.includes(nameB) || nameB.includes(nameA); + } + } + return isValid; +} + +/** + * Patches the LiteGraph.isValidConnection so old nodes can connect to this new COMBO type for all + * lists (without users needing to go through and re-create all their nodes one by one). + */ +const oldIsValidConnection = LiteGraph.isValidConnection; +LiteGraph.isValidConnection = function (typeA: string | string[], typeB: string | string[]) { + let isValid = oldIsValidConnection.call(LiteGraph, typeA, typeB); + if (!isValid) { + typeA = String(typeA); + typeB = String(typeB); + // This is waaaay too liberal and now any combos can connect to any combos. But we only have the + // types (not names like my util above), and connecting too liberally is better than old nodes + // with lists not being able to connect to this new COMBO type. And, anyway, it matches the + // current behavior today with new nodes anyway, where all lists are COMBO types. + // Refs: https://github.com/comfyanonymous/ComfyUI/issues/1674 + // https://github.com/comfyanonymous/ComfyUI/pull/1675 + let areCombos = + (typeA.includes(",") && typeB === "COMBO") || (typeA === "COMBO" && typeB.includes(",")); + isValid = areCombos; + } + return isValid; +}; diff --git a/rgthree-comfy/src_web/comfyui/utils_canvas.ts b/rgthree-comfy/src_web/comfyui/utils_canvas.ts new file mode 100644 index 0000000000000000000000000000000000000000..1c1797092369f8f43243a6dbcfa581a4c5bcc9f3 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/utils_canvas.ts @@ -0,0 +1,295 @@ +import { app } from "scripts/app.js"; +import type { LGraphCanvas as TLGraphCanvas, Vector2 } from "../typings/litegraph.js"; + +function binarySearch(max: number, getValue: (n: number) => number, match: number) { + let min = 0; + + while (min <= max) { + let guess = Math.floor((min + max) / 2); + const compareVal = getValue(guess); + + if (compareVal === match) return guess; + if (compareVal < match) min = guess + 1; + else max = guess - 1; + } + + return max; +} + +/** + * Fits a string against a max width for a ctx. Font should be defined on ctx beforehand. 
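The COMBO compatibility logic in `isValidConnection` (utils.ts above) can be illustrated with made-up slot objects; the names are chosen to show the `_NAME`/`CKPT` normalization, and the `as any` casts are only because full `INodeSlot` objects aren't built here:

// Illustrative slots: a legacy list-typed output vs. a newer COMBO-typed input.
const legacyOutput = { name: "ckpt_name", type: "model_a.safetensors,model_b.safetensors" };
const comboInput = { name: "model", type: "COMBO" };
// Both names normalize to "MODEL", so this pairing is treated as a valid connection.
console.log(isValidConnection(legacyOutput as any, comboInput as any)); // true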
+ */ +export function fitString(ctx: CanvasRenderingContext2D, str: string, maxWidth: number) { + let width = ctx.measureText(str).width; + const ellipsis = "…"; + const ellipsisWidth = measureText(ctx, ellipsis); + if (width <= maxWidth || width <= ellipsisWidth) { + return str; + } + + const index = binarySearch( + str.length, + (guess) => measureText(ctx, str.substring(0, guess)), + maxWidth - ellipsisWidth, + ); + + return str.substring(0, index) + ellipsis; +} + +/** Measures the width of text for a canvas context. */ +export function measureText(ctx: CanvasRenderingContext2D, str: string) { + return ctx.measureText(str).width; +} + +export type WidgetRenderingOptionsPart = { + type?: "toggle" | "custom"; + margin?: number; + fillStyle?: string; + strokeStyle?: string; + lowQuality?: boolean; + draw?(ctx: CanvasRenderingContext2D, x: number, lowQuality: boolean): number; +}; + +type WidgetRenderingOptions = { + width: number; + height: number; + posX?: number; + posY: number; + borderRadius?: number; + colorStroke?: string; + colorBackground?: string; + // node: LGraphNode; + // value?: any; + // margin?: number; + // direction?: "right" | "left"; + // fillStyle?: string; + // strokeStyle?: string; + // parts: WidgetRenderingOptionsPart[]; +}; + +export function isLowQuality() { + const canvas = app.canvas as TLGraphCanvas; + return (canvas.ds?.scale || 1) <= 0.5; +} + +export function drawNodeWidget(ctx: CanvasRenderingContext2D, options: WidgetRenderingOptions) { + const lowQuality = isLowQuality(); + + const data = { + width: options.width, + height: options.height, + posY: options.posY, + lowQuality, + margin: 15, + colorOutline: LiteGraph.WIDGET_OUTLINE_COLOR, + colorBackground: LiteGraph.WIDGET_BGCOLOR, + colorText: LiteGraph.WIDGET_TEXT_COLOR, + colorTextSecondary: LiteGraph.WIDGET_SECONDARY_TEXT_COLOR, + }; + + // Draw background. + ctx.strokeStyle = options.colorStroke || data.colorOutline; + ctx.fillStyle = options.colorBackground || data.colorBackground; + ctx.beginPath(); + ctx.roundRect( + data.margin, + data.posY, + data.width - data.margin * 2, + data.height, + lowQuality ? [0] : options.borderRadius ? [options.borderRadius] : [options.height * 0.5], + ); + ctx.fill(); + if (!lowQuality) { + ctx.stroke(); + } + + return data; +} + +/** Draws a rounded rectangle. */ +export function drawRoundedRectangle( + ctx: CanvasRenderingContext2D, + options: WidgetRenderingOptions, +) { + const lowQuality = isLowQuality(); + options = { ...options }; + ctx.strokeStyle = options.colorStroke || LiteGraph.WIDGET_OUTLINE_COLOR; + ctx.fillStyle = options.colorBackground || LiteGraph.WIDGET_BGCOLOR; + ctx.beginPath(); + ctx.roundRect( + options.posX!, + options.posY, + options.width, + options.height, + lowQuality ? [0] : options.borderRadius ? [options.borderRadius] : [options.height * 0.5], + ); + ctx.fill(); + !lowQuality && ctx.stroke(); +} + +type DrawNumberWidgetPartOptions = { + posX: number; + posY: number; + height: number; + value: number; + direction?: 1 | -1; + textColor?: string; +}; + +/** + * Draws a number picker with arrows off to each side. + * + * This is for internal widgets that may have many hit areas (full-width, default number widgets put + * the arrows on either side of the full-width row). 
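A hedged sketch of how a custom widget's `draw()` might combine the canvas helpers above: draw the standard widget pill with `drawNodeWidget`, then an ellipsized label via `fitString` (the label text and margins are illustrative).

// Hypothetical draw routine for a simple read-only text widget.
function drawSimpleTextWidget(ctx: CanvasRenderingContext2D, width: number, posY: number, height: number) {
  const data = drawNodeWidget(ctx, { width, height, posY });
  if (!data.lowQuality) {
    ctx.textAlign = "left";
    ctx.textBaseline = "middle";
    ctx.fillStyle = data.colorText;
    const maxWidth = width - data.margin * 2 - 10;
    ctx.fillText(
      fitString(ctx, "Some fairly long label that may get ellipsized", maxWidth),
      data.margin + 5,
      posY + height / 2,
    );
  }
}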
+ */ +export function drawNumberWidgetPart( + ctx: CanvasRenderingContext2D, + options: DrawNumberWidgetPartOptions, +): [Vector2, Vector2, Vector2] { + const arrowWidth = 9; + const arrowHeight = 10; + const innerMargin = 3; + const numberWidth = 32; + + const xBoundsArrowLess: Vector2 = [0, 0]; + const xBoundsNumber: Vector2 = [0, 0]; + const xBoundsArrowMore: Vector2 = [0, 0]; + + ctx.save(); + + let posX = options.posX; + const { posY, height, value, textColor } = options; + const midY = posY + height / 2; + + // If we're drawing parts from right to left (usually when something in the middle will be + // flexible), then we can simply move left the expected width of our widget and draw forwards. + if (options.direction === -1) { + posX = posX - arrowWidth - innerMargin - numberWidth - innerMargin - arrowWidth; + } + + // Draw the strength left arrow. + ctx.fill( + new Path2D( + `M ${posX} ${midY} l ${arrowWidth} ${ + arrowHeight / 2 + } l 0 -${arrowHeight} L ${posX} ${midY} z`, + ), + ); + + xBoundsArrowLess[0] = posX; + xBoundsArrowLess[1] = arrowWidth; + posX += arrowWidth + innerMargin; + + // Draw the strength text. + ctx.textAlign = "center"; + ctx.textBaseline = "middle"; + const oldTextcolor = ctx.fillStyle; + if (textColor) { + ctx.fillStyle = textColor; + } + ctx.fillText(fitString(ctx, value.toFixed(2), numberWidth), posX + numberWidth / 2, midY); + ctx.fillStyle = oldTextcolor; + + xBoundsNumber[0] = posX; + xBoundsNumber[1] = numberWidth; + posX += numberWidth + innerMargin; + + // Draw the strength right arrow. + ctx.fill( + new Path2D( + `M ${posX} ${midY - arrowHeight / 2} l ${arrowWidth} ${arrowHeight / 2} l -${arrowWidth} ${ + arrowHeight / 2 + } v -${arrowHeight} z`, + ), + ); + + xBoundsArrowMore[0] = posX; + xBoundsArrowMore[1] = arrowWidth; + + ctx.restore(); + + return [xBoundsArrowLess, xBoundsNumber, xBoundsArrowMore]; +} +drawNumberWidgetPart.WIDTH_TOTAL = 9 + 3 + 32 + 3 + 9; + +type DrawTogglePartOptions = { + posX: number; + posY: number; + height: number; + value: boolean | null; +}; + +/** + * Draws a toggle for a widget. The toggle is a three-way switch with left being false, right being + * true, and a middle state being null. + */ +export function drawTogglePart( + ctx: CanvasRenderingContext2D, + options: DrawTogglePartOptions, +): Vector2 { + const lowQuality = isLowQuality(); + ctx.save(); + + const { posX, posY, height, value } = options; + + const toggleRadius = height * 0.36; // This is the standard toggle height calc. + const toggleBgWidth = height * 1.5; // We don't draw a separate bg, but this would be it. + + // Toggle Track + if (!lowQuality) { + ctx.beginPath(); + ctx.roundRect(posX + 4, posY + 4, toggleBgWidth - 8, height - 8, [height * 0.5]); + ctx.globalAlpha = app.canvas.editor_alpha * 0.25; + ctx.fillStyle = "rgba(255,255,255,0.45)"; + ctx.fill(); + ctx.globalAlpha = app.canvas.editor_alpha; + } + + // Toggle itself + ctx.fillStyle = value === true ? "#89B" : "#888"; + const toggleX = + lowQuality || value === false + ? posX + height * 0.5 + : value === true + ? 
posX + height + : posX + height * 0.75; + ctx.beginPath(); + ctx.arc(toggleX, posY + height * 0.5, toggleRadius, 0, Math.PI * 2); + ctx.fill(); + + ctx.restore(); + + return [posX, toggleBgWidth]; +} + +export function drawInfoIcon( + ctx: CanvasRenderingContext2D, + x: number, + y: number, + size: number = 12, +) { + ctx.save(); + ctx.beginPath(); + ctx.roundRect(x, y, size, size, [size * 0.1]); + ctx.fillStyle = "#2f82ec"; + ctx.strokeStyle = "#0f2a5e"; + ctx.fill(); + // ctx.stroke(); + ctx.strokeStyle = "#FFF"; + ctx.lineWidth = 2; + // ctx.lineCap = 'round'; + const midX = x + size / 2; + const serifSize = size * 0.175; + ctx.stroke( + new Path2D(` + M ${midX} ${y + size * 0.15} + v 2 + M ${midX - serifSize} ${y + size * 0.45} + h ${serifSize} + v ${size * 0.325} + h ${serifSize} + h -${serifSize * 2} + `), + ); + ctx.restore(); +} diff --git a/rgthree-comfy/src_web/comfyui/utils_inputs_outputs.ts b/rgthree-comfy/src_web/comfyui/utils_inputs_outputs.ts new file mode 100644 index 0000000000000000000000000000000000000000..ae007d57fa157c3c830c2df0901ed1a0478e9d14 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/utils_inputs_outputs.ts @@ -0,0 +1,14 @@ +import type { LGraphNode } from "typings/litegraph.js"; + +/** Removes all inputs from the end. */ +export function removeUnusedInputsFromEnd(node: LGraphNode, minNumber = 1, nameMatch?: RegExp) { + for (let i = node.inputs.length - 1; i >= minNumber; i--) { + if (!node.inputs[i]?.link) { + if (!nameMatch || nameMatch.test(node.inputs[i]!.name)) { + node.removeInput(i); + } + continue; + } + break; + } +} \ No newline at end of file diff --git a/rgthree-comfy/src_web/comfyui/utils_menu.ts b/rgthree-comfy/src_web/comfyui/utils_menu.ts new file mode 100644 index 0000000000000000000000000000000000000000..ca89b59bed1652b1bc880c16b95e36e72a65c2c6 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/utils_menu.ts @@ -0,0 +1,98 @@ +import { app } from "scripts/app.js"; +import { rgthreeApi } from "rgthree/common/rgthree_api.js"; + +import type { + LGraphCanvas as TLGraphCanvas, + IWidget, + LGraphNode, + ContextMenuEventListener, + ContextMenu, + IContextMenuItem, +} from "../typings/litegraph.js"; + +const PASS_THROUGH = function (item: T) { + return item as T; +}; + +/** + * Shows a lora chooser context menu. + */ +export async function showLoraChooser( + event: PointerEvent, + callback: ContextMenuEventListener, + parentMenu?: ContextMenu | null, + loras?: string[], +) { + const canvas = app.canvas as TLGraphCanvas; + if (!loras) { + loras = ["None", ...(await rgthreeApi.getLoras())]; + } + new LiteGraph.ContextMenu(loras, { + event: event, + parentMenu, + title: "Choose a lora", + scale: Math.max(1, canvas.ds?.scale ?? 1), + className: "dark", + callback, + }); +} + +/** + * Shows a context menu chooser of nodes. + * + * @param mapFn The function used to map each node to the context menu item. If null is returned + * it will be filtered out (rather than use a separate filter method). + */ +export function showNodesChooser( + event: PointerEvent, + mapFn: (n: LGraphNode) => T | null, + callback: ContextMenuEventListener, + parentMenu?: ContextMenu, +) { + const canvas = app.canvas as TLGraphCanvas; + const nodesOptions: T[] = (app.graph._nodes as LGraphNode[]) + .map(mapFn) + .filter((e): e is NonNullable => e != null); + + nodesOptions.sort((a: any, b: any) => { + return a.value - b.value; + }); + + new LiteGraph.ContextMenu(nodesOptions, { + event: event, + parentMenu, + title: "Choose a node id", + scale: Math.max(1, canvas.ds?.scale ?? 
1), + className: "dark", + callback, + }); +} + +/** + * Shows a conmtext menu chooser for a specific node. + * + * @param mapFn The function used to map each node to the context menu item. If null is returned + * it will be filtered out (rather than use a separate filter method). + */ +export function showWidgetsChooser( + event: PointerEvent | MouseEvent, + node: LGraphNode, + mapFn: (n: IWidget) => T | null, + callback: ContextMenuEventListener, + parentMenu?: ContextMenu, +) { + const options: T[] = (node.widgets || []) + .map(mapFn) + .filter((e): e is NonNullable => e != null); + if (options.length) { + const canvas = app.canvas as TLGraphCanvas; + new LiteGraph.ContextMenu(options, { + event, + parentMenu, + title: "Choose an input/widget", + scale: Math.max(1, canvas.ds?.scale ?? 1), + className: "dark", + callback, + }); + } +} diff --git a/rgthree-comfy/src_web/comfyui/utils_widgets.ts b/rgthree-comfy/src_web/comfyui/utils_widgets.ts new file mode 100644 index 0000000000000000000000000000000000000000..2fb565773a397f5d66de8280f4b856c5aa3551d6 --- /dev/null +++ b/rgthree-comfy/src_web/comfyui/utils_widgets.ts @@ -0,0 +1,461 @@ +import { app } from "scripts/app.js"; +import type { + IWidget, + LGraphNode, + LGraphCanvas as TLGraphCanvas, + Vector2, + AdjustedMouseEvent, + Vector4, +} from "../typings/litegraph.js"; +import { drawNodeWidget, drawRoundedRectangle, fitString, isLowQuality } from "./utils_canvas.js"; + +/** + * Draws a label on teft, and a value on the right, ellipsizing when out of space. + */ +export function drawLabelAndValue( + ctx: CanvasRenderingContext2D, + label: string, + value: string, + width: number, + posY: number, + height: number, + options?: { offsetLeft: number }, +) { + const outerMargin = 15; + const innerMargin = 10; + const midY = posY + height / 2; + ctx.save(); + ctx.textAlign = "left"; + ctx.textBaseline = "middle"; + ctx.fillStyle = LiteGraph.WIDGET_SECONDARY_TEXT_COLOR; + const labelX = outerMargin + innerMargin + (options?.offsetLeft ?? 0); + ctx.fillText(label, labelX, midY); + + const valueXLeft = labelX + ctx.measureText(label).width + 7; + const valueXRight = width - (outerMargin + innerMargin); + + ctx.fillStyle = LiteGraph.WIDGET_TEXT_COLOR; + ctx.textAlign = "right"; + ctx.fillText(fitString(ctx, value, valueXRight - valueXLeft), valueXRight, midY); + ctx.restore(); +} + +export type RgthreeBaseWidgetBounds = { + /** The bounds, either [x, width] assuming the full height, or [x, y, width, height] if height. */ + bounds: Vector2 | Vector4; + onDown?(event: AdjustedMouseEvent, pos: Vector2, node: LGraphNode): boolean | void; + onDown?( + event: AdjustedMouseEvent, + pos: Vector2, + node: LGraphNode, + bounds: RgthreeBaseWidgetBounds, + ): boolean | void; + onUp?(event: AdjustedMouseEvent, pos: Vector2, node: LGraphNode): boolean | void; + onUp?( + event: AdjustedMouseEvent, + pos: Vector2, + node: LGraphNode, + bounds: RgthreeBaseWidgetBounds, + ): boolean | void; + onMove?(event: AdjustedMouseEvent, pos: Vector2, node: LGraphNode): boolean | void; + onMove?( + event: AdjustedMouseEvent, + pos: Vector2, + node: LGraphNode, + bounds: RgthreeBaseWidgetBounds, + ): boolean | void; + data?: any; +}; + +export type RgthreeBaseHitAreas = { + [K in Keys]: RgthreeBaseWidgetBounds; +}; + +/** + * A base widget that handles mouse events more properly. 
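As an illustrative sketch of the chooser helpers from utils_menu.ts above (the handler below just logs; depending on the LiteGraph version the callback value may arrive as a plain string or a `{content}` item, so both are handled):

// Illustrative: open the lora chooser from a node's pointer handler.
async function onNodePointerDown(event: PointerEvent) {
  await showLoraChooser(event, (value) => {
    const chosen = typeof value === "string" ? value : value?.content;
    console.log("chose lora:", chosen);
  });
}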
+ */ +export abstract class RgthreeBaseWidget implements IWidget { + // We don't want our value to be an array as a widget will be serialized as an "input" for the API + // which uses an array value to represent a link. To keep things simpler, we'll avoid using an + // array at all. + abstract value: T extends Array ? never : T; + + name: string; + last_y: number = 0; + + protected mouseDowned: Vector2 | null = null; + protected isMouseDownedAndOver: boolean = false; + + // protected hitAreas: {[key: string]: RgthreeBaseWidgetBounds} = {}; + protected readonly hitAreas: RgthreeBaseHitAreas = {}; + private downedHitAreasForMove: RgthreeBaseWidgetBounds[] = []; + + constructor(name: string) { + this.name = name; + } + + private clickWasWithinBounds(pos: Vector2, bounds: Vector2 | Vector4) { + let xStart = bounds[0]; + let xEnd = xStart + (bounds.length > 2 ? bounds[2]! : bounds[1]!); + const clickedX = pos[0] >= xStart && pos[0] <= xEnd; + if (bounds.length === 2) { + return clickedX; + } + return clickedX && pos[1] >= bounds[1] && pos[1] <= bounds[1] + bounds[3]; + } + + mouse(event: AdjustedMouseEvent, pos: Vector2, node: LGraphNode) { + const canvas = app.canvas as TLGraphCanvas; + + if (event.type == "pointerdown") { + this.mouseDowned = [...pos]; + this.isMouseDownedAndOver = true; + this.downedHitAreasForMove.length = 0; + // Loop over out bounds data and call any specifics. + let anyHandled = false; + for (const part of Object.values(this.hitAreas)) { + if ((part.onDown || part.onMove) && this.clickWasWithinBounds(pos, part.bounds)) { + if (part.onMove) { + this.downedHitAreasForMove.push(part); + } + if (part.onDown) { + const thisHandled = part.onDown.apply(this, [event, pos, node, part]); + anyHandled = anyHandled || thisHandled == true; + } + } + } + return this.onMouseDown(event, pos, node) ?? anyHandled; + } + + // This only fires when LiteGraph has a node_widget (meaning it's pressed), but we may not be + // the original widget pressed, so we still need `mouseDowned`. + if (event.type == "pointerup") { + if (!this.mouseDowned) return true; + this.downedHitAreasForMove.length = 0; + this.cancelMouseDown(); + let anyHandled = false; + for (const part of Object.values(this.hitAreas)) { + if (part.onUp && this.clickWasWithinBounds(pos, part.bounds)) { + const thisHandled = part.onUp.apply(this, [event, pos, node, part]); + anyHandled = anyHandled || thisHandled == true; + } + } + return this.onMouseUp(event, pos, node) ?? anyHandled; + } + + // This only fires when LiteGraph has a node_widget (meaning it's pressed). + if (event.type == "pointermove") { + this.isMouseDownedAndOver = !!this.mouseDowned; + // If we've moved off the button while pressing, then consider us no longer pressing. + if ( + this.mouseDowned && + (pos[0] < 15 || + pos[0] > node.size[0] - 15 || + pos[1] < this.last_y || + pos[1] > this.last_y + LiteGraph.NODE_WIDGET_HEIGHT) + ) { + this.isMouseDownedAndOver = false; + } + for (const part of this.downedHitAreasForMove) { + part.onMove!.apply(this, [event, pos, node, part]); + } + return this.onMouseMove(event, pos, node) ?? true; + } + return false; + } + + /** Sometimes we want to cancel a mouse down, so that an up/move aren't fired. */ + cancelMouseDown() { + this.mouseDowned = null; + this.isMouseDownedAndOver = false; + this.downedHitAreasForMove.length = 0; + } + + /** An event that fires when the pointer is pressed down (once). 
*/ + onMouseDown(event: AdjustedMouseEvent, pos: Vector2, node: LGraphNode): boolean | void { + return; + } + + /** + * An event that fires when the pointer is let go. Only fires if this was the widget that was + * originally pressed down. + */ + onMouseUp(event: AdjustedMouseEvent, pos: Vector2, node: LGraphNode): boolean | void { + return; + } + + /** + * An event that fires when the pointer is moving after pressing down. Will fire both on and off + * of the widget. Check `isMouseDownedAndOver` to determine if the mouse is currently over the + * widget or not. + */ + onMouseMove(event: AdjustedMouseEvent, pos: Vector2, node: LGraphNode): boolean | void { + return; + } +} + +/** + * A better implementation of the LiteGraph button widget. + */ +export class RgthreeBetterButtonWidget extends RgthreeBaseWidget { + value: string = ""; + mouseUpCallback: (event: AdjustedMouseEvent, pos: Vector2, node: LGraphNode) => boolean | void; + + constructor( + name: string, + mouseUpCallback: (event: AdjustedMouseEvent, pos: Vector2, node: LGraphNode) => boolean | void, + ) { + super(name); + this.mouseUpCallback = mouseUpCallback; + } + + draw(ctx: CanvasRenderingContext2D, node: LGraphNode, width: number, y: number, height: number) { + drawWidgetButton({ctx, node, width, height, y}, this.name, this.isMouseDownedAndOver); + } + + override onMouseUp(event: AdjustedMouseEvent, pos: Vector2, node: LGraphNode) { + return this.mouseUpCallback(event, pos, node); + } +} + +/** + * A better implementation of the LiteGraph text widget, including auto ellipsis. + */ +export class RgthreeBetterTextWidget implements IWidget { + name: string; + value: string; + + constructor(name: string, value: string) { + this.name = name; + this.value = value; + } + + draw(ctx: CanvasRenderingContext2D, node: LGraphNode, width: number, y: number, height: number) { + const widgetData = drawNodeWidget(ctx, { width, height, posY: y }); + + if (!widgetData.lowQuality) { + drawLabelAndValue(ctx, this.name, this.value, width, y, height); + } + } + + mouse(event: MouseEvent, pos: Vector2, node: LGraphNode) { + const canvas = app.canvas as TLGraphCanvas; + if (event.type == "pointerdown") { + canvas.prompt("Label", this.value, (v: string) => (this.value = v), event); + return true; + } + return false; + } +} + +/** + * Options for the Divider Widget. + */ +type RgthreeDividerWidgetOptions = { + marginTop: number; + marginBottom: number; + marginLeft: number; + marginRight: number; + color: string; + thickness: number; +}; + +/** + * A divider widget; can also be used as a spacer if fed a 0 thickness. 
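`RgthreeBaseWidget` above does the pointer down/up/move bookkeeping and hit-area dispatch; the simplest concrete use is `RgthreeBetterButtonWidget`, sketched here (the node instance and button label are hypothetical, and `addCustomWidget` is the stock LiteGraph API for attaching custom widgets):

// Illustrative: attach a button widget that reacts when the press is released over it.
const button = new RgthreeBetterButtonWidget("Refresh Data", (_event, _pos, node) => {
  console.log("button released on", node.title);
  return true; // handled
});
someNode.addCustomWidget(button as any);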
+ */ +export class RgthreeDividerWidget implements IWidget { + options = { serialize: false }; + value = null; + name = "divider"; + + private readonly widgetOptions: RgthreeDividerWidgetOptions = { + marginTop: 7, + marginBottom: 7, + marginLeft: 15, + marginRight: 15, + color: LiteGraph.WIDGET_OUTLINE_COLOR, + thickness: 1, + }; + + constructor(widgetOptions?: Partial) { + Object.assign(this.widgetOptions, widgetOptions || {}); + } + + draw(ctx: CanvasRenderingContext2D, node: LGraphNode, width: number, posY: number, h: number) { + if (this.widgetOptions.thickness) { + ctx.strokeStyle = this.widgetOptions.color; + const x = this.widgetOptions.marginLeft; + const y = posY + this.widgetOptions.marginTop; + const w = width - this.widgetOptions.marginLeft - this.widgetOptions.marginRight; + ctx.stroke(new Path2D(`M ${x} ${y} h ${w}`)); + } + } + + computeSize(width: number): [number, number] { + return [ + width, + this.widgetOptions.marginTop + this.widgetOptions.marginBottom + this.widgetOptions.thickness, + ]; + } +} + +/** + * Options for the Label Widget. + */ +export type RgthreeLabelWidgetOptions = { + align?: "left" | "center" | "right"; + color?: string; + italic?: boolean; + size?: number; + + /** A label to put on the right side. */ + actionLabel?: "__PLUS_ICON__" | string; + actionCallback?: (event: PointerEvent) => void; +}; + +/** + * A simple label widget, drawn with no background. + */ +export class RgthreeLabelWidget implements IWidget { + options = { serialize: false }; + value = null; + name: string; + + private readonly widgetOptions: RgthreeLabelWidgetOptions = {}; + private posY: number = 0; + + constructor(name: string, widgetOptions?: RgthreeLabelWidgetOptions) { + this.name = name; + Object.assign(this.widgetOptions, widgetOptions); + } + + draw( + ctx: CanvasRenderingContext2D, + node: LGraphNode, + width: number, + posY: number, + height: number, + ) { + this.posY = posY; + ctx.save(); + + ctx.textAlign = this.widgetOptions.align || "left"; + ctx.fillStyle = this.widgetOptions.color || LiteGraph.WIDGET_TEXT_COLOR; + const oldFont = ctx.font; + if (this.widgetOptions.italic) { + ctx.font = "italic " + ctx.font; + } + if (this.widgetOptions.size) { + ctx.font = ctx.font.replace(/\d+px/, `${this.widgetOptions.size}px`); + } + + const midY = posY + height / 2; + ctx.textBaseline = "middle"; + + if (this.widgetOptions.align === "center") { + ctx.fillText(this.name, node.size[0] / 2, midY); + } else { + ctx.fillText(this.name, 15, midY); + } // TODO(right); + + ctx.font = oldFont; + + if (this.widgetOptions.actionLabel === "__PLUS_ICON__") { + const plus = new Path2D( + `M${node.size[0] - 15 - 2} ${posY + 7} v4 h-4 v4 h-4 v-4 h-4 v-4 h4 v-4 h4 v4 h4 z`, + ); + ctx.lineJoin = "round"; + ctx.lineCap = "round"; + ctx.fillStyle = "#3a3"; + ctx.strokeStyle = "#383"; + ctx.fill(plus); + ctx.stroke(plus); + } + ctx.restore(); + } + + mouse(event: PointerEvent, nodePos: Vector2, node: LGraphNode) { + if ( + event.type !== "pointerdown" || + isLowQuality() || + !this.widgetOptions.actionLabel || + !this.widgetOptions.actionCallback + ) { + return false; + } + + const pos: Vector2 = [nodePos[0], nodePos[1] - this.posY]; + const rightX = node.size[0] - 15; + if (pos[0] > rightX || pos[0] < rightX - 16) { + return false; + } + this.widgetOptions.actionCallback(event); + return true; + } +} + +/** An invisible widget. 
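A small sketch of using the presentation widgets above on a node (again, `someNode` is a stand-in): a zero-thickness divider acts as a spacer, and the label widget draws flat text with optional styling.

// Illustrative: a spacer, a separator line, and a centered italic label.
someNode.addCustomWidget(new RgthreeDividerWidget({ thickness: 0, marginTop: 4, marginBottom: 4 }) as any);
someNode.addCustomWidget(new RgthreeDividerWidget() as any);
someNode.addCustomWidget(new RgthreeLabelWidget("Advanced Options", { align: "center", italic: true }) as any);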
*/ +export class RgthreeInvisibleWidget implements IWidget { + name: string; + type: string; + value: T; + serializeValue: IWidget['serializeValue'] = undefined; + + constructor(name: string, type: string, value: T, serializeValueFn: ()=> T) { + this.name = name; + this.type = type; + this.value = value; + if (serializeValueFn) { + this.serializeValue = serializeValueFn + } + } + draw() { return; } + computeSize(width: number) : Vector2 { return [0, 0]; } +} + + +type DrawContext = { + ctx: CanvasRenderingContext2D, + node: LGraphNode, + width: number, + y: number, + height: number, +} + +/** + * Draws a better button. + */ +export function drawWidgetButton(drawCtx: DrawContext, text: string, isMouseDownedAndOver: boolean = false) { + // First, add a shadow if we're not down or lowquality. + if (!isLowQuality() && !isMouseDownedAndOver) { + drawRoundedRectangle(drawCtx.ctx, { + width: drawCtx.width - 30 - 2, + height: drawCtx.height, + posY: drawCtx.y + 1, + posX: 15 + 1, + borderRadius: 4, + colorBackground: "#000000aa", + colorStroke: "#000000aa", + }); + } + + drawRoundedRectangle(drawCtx.ctx, { + width: drawCtx.width - 30, + height: drawCtx.height, + posY: drawCtx.y + (isMouseDownedAndOver ? 1 : 0), + posX: 15, + borderRadius: isLowQuality() ? 0 : 4, + colorBackground: isMouseDownedAndOver ? "#444" : LiteGraph.WIDGET_BGCOLOR, + }); + + if (!isLowQuality()) { + drawCtx.ctx.textBaseline = "middle"; + drawCtx.ctx.textAlign = "center"; + drawCtx.ctx.fillStyle = LiteGraph.WIDGET_TEXT_COLOR; + drawCtx.ctx.fillText( + text, + drawCtx.node.size[0] / 2, + drawCtx.y + drawCtx.height / 2 + (isMouseDownedAndOver ? 1 : 0), + ); + } +} \ No newline at end of file diff --git a/rgthree-comfy/src_web/common/css/buttons.scss b/rgthree-comfy/src_web/common/css/buttons.scss new file mode 100644 index 0000000000000000000000000000000000000000..889ada2d29634df2277556832f521dc2bdf1fc1d --- /dev/null +++ b/rgthree-comfy/src_web/common/css/buttons.scss @@ -0,0 +1,106 @@ +:not(#fakeid) .rgthree-button-reset { + position: relative; + appearance: none; + cursor: pointer; + border: 0; + background: transparent; + color: inherit; + padding: 0; + margin: 0; + +} +:not(#fakeid) .rgthree-button { + --padding-top: 7px; + --padding-bottom: 9px; + --padding-x: 16px; + position: relative; + cursor: pointer; + border: 0; + border-radius: 0.25rem; + background: rgba(0, 0, 0, 0.5); + color: white; + font-family: system-ui, sans-serif; + font-size: calc(16rem / 16); + line-height: 1; + white-space: nowrap; + text-decoration: none; + margin: 0.25rem; + box-shadow: 0px 0px 2px rgb(0, 0, 0); + background: #212121; + transition: all 0.1s ease-in-out; + padding: var(--padding-top) var(--padding-x) var(--padding-bottom); + display: inline-flex; + flex-direction: row; + align-items: center; + justify-content: center; + + &::before, + &::after { + content: ""; + display: block; + position: absolute; + border-radius: 0.25rem; + left: 0; + top: 0; + width: 100%; + height: 100%; + box-shadow: + inset 1px 1px 0px rgba(255, 255, 255, 0.12), + inset -1px -1px 0px rgba(0, 0, 0, 0.75); + background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)); + mix-blend-mode: screen; + } + + &::after { + mix-blend-mode: multiply; + } + + &:hover { + background: #303030; + } + &:active { + box-shadow: 0px 0px 0px rgba(0, 0, 0, 0); + background: #121212; + padding: calc(var(--padding-top) + 1px) calc(var(--padding-x) - 1px) + calc(var(--padding-bottom) - 1px) calc(var(--padding-x) + 1px); + } + + &:active::before, + 
&:active::after { + box-shadow: + 1px 1px 0px rgba(255, 255, 255, 0.15), + inset 1px 1px 0px rgba(0, 0, 0, 0.5), + inset 1px 3px 5px rgba(0, 0, 0, 0.33); + } + + &.-blue { + background: #346599 !important; + } + &.-blue:hover { + background: #3b77b8 !important; + } + &.-blue:active { + background: #1d5086 !important; + } + + &.-green { + background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)), #14580b; + } + &.-green:hover { + background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)), #1a6d0f; + } + &.-green:active { + background: linear-gradient(to bottom, rgba(0, 0, 0, 0.15), rgba(255, 255, 255, 0.06)), #0f3f09; + } + + &[disabled] { + box-shadow: none; + background: #666 !important; + color: #aaa; + pointer-events: none; + } + &[disabled]::before, + &[disabled]::after { + display: none; + } +} \ No newline at end of file diff --git a/rgthree-comfy/src_web/common/css/dialog.scss b/rgthree-comfy/src_web/common/css/dialog.scss new file mode 100644 index 0000000000000000000000000000000000000000..bafd1cfd1b889be752f79fa105615a773fa1e6ba --- /dev/null +++ b/rgthree-comfy/src_web/common/css/dialog.scss @@ -0,0 +1,129 @@ + +.rgthree-dialog { + outline: 0; + border: 0; + border-radius: 6px; + background: #414141; + color: #fff; + box-shadow: + inset 1px 1px 0px rgba(255, 255, 255, 0.05), + inset -1px -1px 0px rgba(0, 0, 0, 0.5), + 2px 2px 20px rgb(0, 0, 0); + max-width: 800px; + box-sizing: border-box; + font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif; + font-size: 1rem; + padding: 0; + max-height: calc(100% - 32px); + + *, *::before, *::after { + box-sizing: inherit; + } +} + +.rgthree-dialog-container { + // padding: 16px; + > * { + padding: 8px 16px; + + &:first-child { + padding-top: 16px; + } + &:last-child { + padding-bottom: 16px; + } + } +} + +.rgthree-dialog.-iconed::after { + content: ""; + font-size: 276px; + position: absolute; + right: 0px; + bottom: 0px; + opacity: 0.15; + display: block; + width: 237px; + overflow: hidden; + height: 186px; + line-height: 1; + pointer-events: none; + z-index: -1; +} +.rgthree-dialog.-iconed.-help::after { + content: "🛟"; +} +.rgthree-dialog.-iconed.-settings::after { + content: "⚙️"; +} + +@media (max-width: 832px) { + .rgthree-dialog { + max-width: calc(100% - 32px); + } +} + +.rgthree-dialog-container-title { + display: flex; + flex-direction: row; + align-items: center; + justify-content: start; +} +.rgthree-dialog-container-title > svg:first-child { + width: 36px; + height: 36px; + margin-right: 16px; +} +.rgthree-dialog-container-title h2 { + font-size: calc(22rem / 16); + margin: 0; + font-weight: bold; +} + +.rgthree-dialog-container-title h2 small { + font-size: calc(13rem / 16); + font-weight: normal; + opacity: 0.75; +} + +.rgthree-dialog-container-content { + overflow: auto; + max-height: calc(100vh - 200px); /* Arbitrary height to copensate for margin, title, and footer.*/ +} +.rgthree-dialog-container-content p { + font-size: calc(13rem / 16); + margin-top: 0; +} + +.rgthree-dialog-container-content ul li p { + margin-bottom: 4px; +} + +.rgthree-dialog-container-content ul li p + p { + margin-top: 0.5em; +} + +.rgthree-dialog-container-content ul li ul { + margin-top: 0.5em; + margin-bottom: 1em; +} + +.rgthree-dialog-container-content p code { + display: inline-block; + padding: 2px 4px; + margin: 0px 2px; + border: 1px solid rgba(255, 255, 255, 0.25); + border-radius: 3px; + background: rgba(255, 255, 255, 0.1); +} + +.rgthree-dialog-container-footer { + 
display: flex; + align-items: center; + justify-content: center; +} + +body.rgthree-dialog-open > *:not(.rgthree-dialog):not(.rgthree-top-messages-container) { + filter: blur(5px); +} + diff --git a/rgthree-comfy/src_web/common/css/dialog_lora_chooser.scss b/rgthree-comfy/src_web/common/css/dialog_lora_chooser.scss new file mode 100644 index 0000000000000000000000000000000000000000..728cd87af9645230ca7cfba87a6f5a8a7ff0d387 --- /dev/null +++ b/rgthree-comfy/src_web/common/css/dialog_lora_chooser.scss @@ -0,0 +1,161 @@ + +.rgthree-lora-chooser-dialog { + max-width: 100%; + + + .rgthree-dialog-container-title { + display: flex; + flex-direction: column; + } + .rgthree-dialog-container-title h2 { + display: flex; + width: 100%; + } + .rgthree-lora-chooser-search { + margin-left: auto; + border-radius: 50px; + width: 50%; + max-width: 170px; + padding: 2px 8px; + } + + .rgthree-lora-chooser-header { + display: flex; + flex-direction: row; + } + + .rgthree-lora-filters-container { + svg {width: 16px; height: 16px;} + } + + .rgthree-dialog-container-content { + width: 80vw; + height: 80vh; + } + + .rgthree-button-reset { + width: 32px; + height: 32px; + > svg {width: 100%; height: 100%;} + + } + + ul.rgthree-lora-chooser-list { + list-style: none; + margin: 0; + padding: 0; + position: relative; + display: flex; + flex-direction: row; + flex-wrap: wrap; + align-items: start; + justify-content: space-around; + + > li { + position: relative; + flex: 0 0 auto; + width: 170px; + max-width: 100%; + margin: 8px 8px 16px; + + label { + position: absolute; + display: block; + inset: 0; + z-index: 3; + cursor: pointer; + } + input[type="checkbox"] { + position: absolute; + right: 8px; + top: 8px; + margin: 0; + z-index: 2; + appearance: none; + background-color: #fff; + width: 48px; + height: 48px; + border-radius: 4px; + border: 1px solid rgba(120,120,120,1); + opacity: 0; + transition: opacity 0.15s ease-in-out; + + &:checked { + opacity: 1; + background: #0060df; + &::before { + content: ""; + display: block; + width: 100%; + height: 100%; + box-shadow: inset 100px 100px #fff; + clip-path: polygon(40.13% 68.39%, 23.05% 51.31%, 17.83% 48.26%, 12.61% 49.57%, 9.57% 53.04%, 8% 60%, 34.13% 85.87%, 39.82% 89.57%, 45.88% 86.73%, 90.66% 32.39%, 88.92% 26.1%, 83.03% 22.17%, 76.94% 22.62%) + } + } + } + + + figure { + position: relative; + display: block; + margin: 0 0 8px; + padding: 0; + border: 1px solid rgba(120, 120, 120, .8); + background: rgba(120, 120, 120, .5); + width: 100%; + padding-top: 120%; + transition: box-shadow 0.15s ease-in-out; + opacity: 0.75; + &::after { + content: ''; + display: block; + position: absolute; + inset: 0; + } + + &:empty { + &::before { + content: 'No image.'; + color: rgba(200, 200, 200, .8); + position: absolute; + display: block; + inset: 0; + font-size: 1.2em; + text-align: center; + display: flex; + align-items: center; + justify-content: center; + } + } + + > img, > video { + position: absolute; + width: 100%; + height: 100%; + top: 0; + left: 0; + object-fit: cover; + } + } + div { + word-wrap: break-word; + font-size: 0.8rem; + opacity: 0.75; + } + + &:hover figure::after{ + box-shadow: 0px 2px 6px rgba(0,0,0,0.75); + } + :checked ~ figure::after { + box-shadow: 0 0 5px #fff, 0px 0px 15px rgba(49, 131, 255, 0.88), inset 0 0 3px #fff, inset 0px 0px 5px rgba(49, 131, 255, 0.88) + } + + &:hover *, + &:hover input[type="checkbox"], + :checked ~ * { + opacity: 1 + } + + } + } +} \ No newline at end of file diff --git 
a/rgthree-comfy/src_web/common/css/dialog_model_info.scss b/rgthree-comfy/src_web/common/css/dialog_model_info.scss new file mode 100644 index 0000000000000000000000000000000000000000..283767f068b75bc3032194ca386b51b6f923a878 --- /dev/null +++ b/rgthree-comfy/src_web/common/css/dialog_model_info.scss @@ -0,0 +1,396 @@ + +.rgthree-info-dialog { + + width: 90vw; + max-width: 960px; + + .rgthree-info-area { + list-style: none; + padding: 0; + margin: 0; + display: flex; + + > li { + display: inline-flex; + margin: 0; + vertical-align: top; + + + li { + margin-left: 6px; + } + &:not(.-link) + li.-link { + margin-left: auto; + } + + &.rgthree-info-tag > * { + min-height: 24px; + border-radius: 4px; + line-height: 1; + color: rgba(255,255,255,0.85); + background: rgb(69, 92, 85);; + font-size: 14px; + font-weight: bold; + text-decoration: none; + display: flex; + height: 1.6em; + padding-left: .5em; + padding-right: .5em; + padding-bottom: .1em; + align-content: center; + justify-content: center; + align-items: center; + box-shadow: inset 0px 0px 0 1px rgba(0, 0, 0, 0.5); + + > svg { + width: 16px; + height: 16px; + + &:last-child { + margin-left: .5em; + } + } + + + &[href] { + box-shadow: inset 0px 1px 0px rgba(255,255,255,0.25), inset 0px -1px 0px rgba(0,0,0,0.66); + } + + &:empty { + display: none; + } + } + + // &.-civitai > * { + // color: #ddd; + // background: #1b65aa; + // transition: all 0.15s ease-in-out; + // &:hover { + // color: #fff; + // border-color: #1971c2; + // background: #1971c2; + // } + // } + &.-type > * { + background: rgb(73, 54, 94); + color: rgb(228, 209, 248); + } + + &.rgthree-info-menu { + margin-left: auto; + + :not(#fakeid) & .rgthree-button { + margin: 0; + min-height: 24px; + padding: 0 12px; + } + + svg { + width: 16px; + height: 16px; + } + } + } + } + + .rgthree-info-table { + border-collapse: collapse; + margin: 16px 0px; + width: 100%; + font-size: 12px; + + tr.editable button { + display: flex; + width: 28px; + height: 28px; + align-items: center; + justify-content: center; + + svg + svg {display: none;} + } + tr.editable.-rgthree-editing button { + svg {display: none;} + svg + svg {display: inline-block;} + } + + td { + position: relative; + border: 1px solid rgba(255,255,255,0.25); + padding: 0; + vertical-align: top; + + &:first-child { + background: rgba(255,255,255,0.075); + width: 10px; // Small, so it doesn't adjust. 
+ > *:first-child { + white-space: nowrap; + padding-right: 32px; + } + + small { + display: block; + margin-top: 2px; + opacity: 0.75; + + > [data-action] { + text-decoration: underline; + cursor: pointer; + &:hover { + text-decoration: none; + } + } + } + } + + a, a:hover, a:visited { + color: inherit; + } + + svg { + width: 1.3333em; + height: 1.3333em; + vertical-align: -0.285em; + + &.logo-civitai { + margin-right: 0.3333em; + } + } + + > *:first-child { + display: block; + padding: 6px 10px; + } + + > input, > textarea{ + padding: 5px 10px; + border: 0; + box-shadow: inset 1px 1px 5px 0px rgba(0,0,0,0.5); + font: inherit; + appearance: none; + background: #fff; + color: #121212; + resize: vertical; + + &:only-child { + width: 100%; + } + } + + :not(#fakeid) & .rgthree-button[data-action="fetch-civitai"] { + font-size: inherit; + padding: 6px 16px; + margin: 2px; + } + } + + tr[data-field-name="userNote"] td > span:first-child { + white-space: pre; + } + + tr.rgthree-info-table-break-row td { + border: 0; + background: transparent; + padding: 12px 4px 4px; + font-size: 1.2em; + + > small { + font-style: italic; + opacity: 0.66; + } + + &:empty { + padding: 4px; + } + } + + td .-help { + border: 1px solid currentColor; + position: absolute; + right: 5px; + top: 6px; + line-height: 1; + font-size: 11px; + width: 12px; + height: 12px; + border-radius: 8px; + display: flex; + align-content: center; + justify-content: center; + cursor: help; + &::before { + content: '?'; + } + + } + + td > ul.rgthree-info-trained-words-list { + list-style: none; + padding: 2px 8px; + margin: 0; + display: flex; + flex-direction: row; + flex-wrap: wrap; + max-height: 15vh; + overflow: auto; + + > li { + display: inline-flex; + margin: 2px; + vertical-align: top; + border-radius: 4px; + line-height: 1; + color: rgba(255,255,255,0.85); + background: rgb(73, 91, 106); + font-size: 1.2em; + font-weight: 600; + text-decoration: none; + display: flex; + height: 1.6em; + align-content: center; + justify-content: center; + align-items: center; + box-shadow: inset 0px 0px 0 1px rgba(0, 0, 0, 0.5); + cursor: pointer; + white-space: nowrap; + max-width: 183px; + + &:hover { + background: rgb(68, 109, 142); + } + + > svg { + width: auto; + height: 1.2em; + } + + > span { + padding-left: .5em; + padding-right: .5em; + padding-bottom: .1em; + text-overflow: ellipsis; + overflow: hidden; + } + + > small { + align-self: stretch; + display: flex; + align-items: center; + justify-content: center; + padding: 0 0.5em; + background: rgba(0,0,0,0.2); + } + + &.-rgthree-is-selected { + background: rgb(42, 126, 193); + } + } + } + } + + .rgthree-info-images { + list-style:none; + padding:0; + margin:0; + scroll-snap-type: x mandatory; + display:flex; + flex-direction:row; + overflow: auto; + + > li { + scroll-snap-align: start; + max-width: 90%; + flex: 0 0 auto; + display: flex; + align-items: center; + justify-content: center; + flex-direction: column; + overflow: hidden; + padding: 0; + margin: 6px; + font-size: 0; + position: relative; + + figure { + margin: 0; + position: static; + + figcaption { + position: absolute; + left: 0; + width: 100%; + bottom: 0; + padding: 12px; + font-size: 12px; + background: rgba(0,0,0,0.85); + opacity: 0; + transform: translateY(50px); + transition: all 0.25s ease-in-out; + + > span { + display: inline-block; + padding: 2px 4px; + margin: 2px; + border-radius: 2px; + border: 1px solid rgba(255,255,255,0.2); + word-break: break-word; + + label { + display: inline; + padding: 0; + margin: 0; + 
opacity: 0.5; + pointer-events: none; + user-select: none; + } + a { + color: inherit; + text-decoration: underline; + &:hover { + text-decoration: none; + } + + svg { + height: 10px; + margin-left: 4px; + fill: currentColor; + } + } + } + &:empty { + text-align: center; + + &::before { + content: 'No data.'; + } + } + } + } + + &:hover figure figcaption { + opacity: 1; + transform: translateY(0px); + } + + .rgthree-info-table { + width: calc(100% - 16px); + } + } + } + + .rgthree-info-civitai-link { + margin: 8px; + color: #eee; + + a, a:hover, a:visited { + color: inherit; + text-decoration: none; + } + + > svg { + width: 16px; + height: 16px; + margin-right: 8px; + } + } +} + + diff --git a/rgthree-comfy/src_web/common/css/menu.scss b/rgthree-comfy/src_web/common/css/menu.scss new file mode 100644 index 0000000000000000000000000000000000000000..40a4d5164775d9e131641cd120d4ef31109a7359 --- /dev/null +++ b/rgthree-comfy/src_web/common/css/menu.scss @@ -0,0 +1,108 @@ + + +.rgthree-menu { + list-style: none; + padding: 0; + margin: 0; + position: fixed; + z-index: 999999; + pointer-events: none; + opacity: 0; + transition: opacity 0.08s ease-in-out; + + color: #dde; + background-color: #111; + font-size: 12px; + box-shadow: 0 0 10px black !important; + + > li { + position: relative; + padding: 4px 6px; + z-index: 9999; + white-space: nowrap; + + &[role="button"] { + background-color: var(--comfy-menu-bg) !important; + color: var(--input-text); + cursor: pointer; + &:hover { + filter: brightness(155%); + } + } + } + + &[state^="measuring"] { + display: block; + opacity: 0; + } + &[state="open"] { + display: block; + opacity: 1; + pointer-events: all; + } +} + + +.rgthree-top-menu { + box-sizing: border-box; + white-space: nowrap; + background: var(--content-bg); + color: var(--content-fg); + display: flex; + flex-direction: column; + * { + box-sizing: inherit; + } + + menu { + list-style: none; + padding: 0; + margin: 0; + + > li:not(#fakeid) { + list-style: none; + padding: 0; + margin: 0; + + > button { + cursor: pointer; + padding: 8px 12px 8px 8px; + width: 100%; + text-align: start; + display: flex; + flex-direction: row; + align-items: center; + justify-content: start; + + &:hover { + background-color: var(--comfy-input-bg); + } + + svg { + height: 16px; + width: auto; + margin-inline-end: 0.6em; + + &.github-star { + fill: rgb(227, 179, 65); + } + } + } + + &.rgthree-message { + // ComfyUI's code has strange behavior that that always puts the popupat to if its less than + // 30px... we'll force our message to be at least 32px tall so it won't do that unless it's + // actually on the bottom. 
+ min-height: 32px; + > span { + padding: 8px 12px; + display: block; + width: 100%; + text-align: center; + font-style: italic; + font-size: 12px; + } + } + } + } +} diff --git a/rgthree-comfy/src_web/common/css/pages_base.scss b/rgthree-comfy/src_web/common/css/pages_base.scss new file mode 100644 index 0000000000000000000000000000000000000000..12d883af885dd2374af9d963b43d54e034d12629 --- /dev/null +++ b/rgthree-comfy/src_web/common/css/pages_base.scss @@ -0,0 +1,69 @@ + +html, body { + +} +html { + font-size: 100%; + overflow-y: scroll; + -webkit-text-size-adjust: 100%; + -ms-text-size-adjust: 100%; + box-sizing: border-box; +} +*, *:before, *:after { + box-sizing: inherit +} + +:root { + --header-height: 56px; + --progress-height: 12px; +} + +button { + all: unset; +} + +.-bevel { + position: relative; +} +.-bevel::before { + content: ''; + position: absolute; + left: 0; + top: 0; + width: 100%; + height: 100%; + border: 1px solid red; + border-color: rgba(255,255,255,0.15) rgba(255,255,255,0.15) rgba(0,0,0,0.5) rgba(0,0,0,0.5); + z-index: 5; + pointer-events: none; +} + + +body { + background: #202020; + font-family: Arial, sans-serif; + font-size: calc(16 * 0.0625rem); + font-weight: 400; + margin: 0; + padding-top: calc(var(--header-height) + var(--progress-height)); + color: #ffffff; + display: flex; + flex-direction: column; + align-items: center; + justify-content: start; +} + +.app-header { + height: var( --header-height); + padding: 0; + position: fixed; + z-index: 99; + top: 0; + left: 0; + width: 100%; + background: #353535; + display: flex; + flex-direction: row; + align-items: center; + justify-content: start; +} diff --git a/rgthree-comfy/src_web/common/dialog.ts b/rgthree-comfy/src_web/common/dialog.ts new file mode 100644 index 0000000000000000000000000000000000000000..5087196e7bd5e6810cfcdc16cbfefab4bac406d9 --- /dev/null +++ b/rgthree-comfy/src_web/common/dialog.ts @@ -0,0 +1,165 @@ +import type { LGraphNode, LGraphNodeConstructor } from "typings/litegraph.js"; +import { createElement as $el, getClosestOrSelf, setAttributes } from "./utils_dom.js"; + +type RgthreeDialogButton = { + label: string; + className?: string; + closes?: boolean; + disabled?: boolean; + callback?: (e: PointerEvent | MouseEvent) => void; +}; + +export type RgthreeDialogOptions = { + content: string | HTMLElement | HTMLElement[]; + class?: string | string[]; + title?: string | HTMLElement | HTMLElement[]; + closeX?: boolean; + closeOnEsc?: boolean; + closeOnModalClick?: boolean; + closeButtonLabel?: string | boolean; + buttons?: RgthreeDialogButton[]; + onBeforeClose?: () => Promise | boolean; +}; + +/** + * A Dialog that shows content, and closes. + */ +export class RgthreeDialog extends EventTarget { + element: HTMLDialogElement; + contentElement: HTMLDivElement; + titleElement: HTMLDivElement; + options: RgthreeDialogOptions; + + constructor(options: RgthreeDialogOptions) { + super(); + this.options = options; + let container = $el("div.rgthree-dialog-container"); + this.element = $el("dialog", { + classes: ["rgthree-dialog", options.class || ""], + child: container, + parent: document.body, + events: { + click: (event: MouseEvent) => { + // Close the dialog if we've clicked outside of our container. The dialog modal will + // report itself as the dialog itself, so we use the inner container div (and CSS to + // remove default padding from the dialog element). 
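+          // (A click on the dialog's native backdrop reports `event.target` as the dialog element
+          // itself, which matches none of the checks below and so falls through to `close()`.)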
+ if ( + !this.element.open || + event.target === container || + getClosestOrSelf(event.target, `.rgthree-dialog-container`) === container + ) { + return; + } + return this.close(); + }, + }, + }); + this.element.addEventListener("close", (event) => { + this.onDialogElementClose(); + }); + + this.titleElement = $el("div.rgthree-dialog-container-title", { + parent: container, + children: !options.title + ? null + : options.title instanceof Element || Array.isArray(options.title) + ? options.title + : typeof options.title === "string" + ? !options.title.includes(" { + button.callback?.(e); + }, + }, + }); + } + + if (options.closeButtonLabel !== false) { + $el("button", { + text: options.closeButtonLabel || "Close", + className: "rgthree-button", + parent: footerEl, + events: { + click: (e: MouseEvent) => { + this.close(e); + }, + }, + }); + } + } + + setTitle(content: string | HTMLElement | HTMLElement[]) { + const title = + typeof content !== "string" || content.includes(" = {}, + ) { + const title = (node.type || node.title || "").replace( + /\s*\(rgthree\).*/, + " by rgthree", + ); + const options = Object.assign({}, opts, { + class: "-iconed -help", + title, + content, + }); + super(options); + } +} diff --git a/rgthree-comfy/src_web/common/link_fixer.ts b/rgthree-comfy/src_web/common/link_fixer.ts new file mode 100644 index 0000000000000000000000000000000000000000..54491da5df3fe467e791f4676f0111c38ea36eb6 --- /dev/null +++ b/rgthree-comfy/src_web/common/link_fixer.ts @@ -0,0 +1,392 @@ +import type { BadLinksData, SerializedGraph, SerializedLink, SerializedNode } from "typings/index.js"; +import type { LGraph, LGraphNode, LLink, serializedLGraph } from "typings/litegraph.js"; + +enum IoDirection { + INPUT, + OUTPUT, +} + +function getNodeById(graph: SerializedGraph | LGraph | serializedLGraph, id: number) { + if ((graph as LGraph).getNodeById) { + return (graph as LGraph).getNodeById(id); + } + graph = graph as SerializedGraph; + return graph.nodes.find((n) => n.id === id)!; +} + +function extendLink(link: SerializedLink) { + return { + link: link, + id: link[0], + origin_id: link[1], + origin_slot: link[2], + target_id: link[3], + target_slot: link[4], + type: link[5], + }; +} + +/** + * Takes a SerializedGraph or live LGraph and inspects the links and nodes to ensure the linking + * makes logical sense. Can apply fixes when passed the `fix` argument as true. + * + * Note that fixes are a best-effort attempt. Seems to get it correct in most cases, but there is a + * chance it correct an anomoly that results in placing an incorrect link (say, if there were two + * links in the data). Users should take care to not overwrite work until manually checking the + * result. + */ +export function fixBadLinks( + graph: SerializedGraph | LGraph, + fix = false, + silent = false, + logger: { log: (...args: any[]) => void } = console, +): BadLinksData { + const patchedNodeSlots: { + [nodeId: string]: { + inputs?: { [slot: number]: number | null }; + outputs?: { + [slots: number]: { + links: number[]; + changes: { [linkId: number]: "ADD" | "REMOVE" }; + }; + }; + }; + } = {}; + // const logger = this.newLogSession("[findBadLinks]"); + const data: { patchedNodes: Array; deletedLinks: number[] } = { + patchedNodes: [], + deletedLinks: [], + }; + + /** + * Internal patch node. We keep track of changes in patchedNodeSlots in case we're in a dry run. 
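+   *
+   * For illustration only (node and link ids below are made up), after a few calls
+   * `patchedNodeSlots` could look like:
+   *
+   *     {
+   *       "12": { inputs: { 0: null } },
+   *       "7": { outputs: { 1: { links: [53], changes: { 53: "ADD" } } } },
+   *     }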
+ */ + async function patchNodeSlot( + node: SerializedNode | LGraphNode, + ioDir: IoDirection, + slot: number, + linkId: number, + op: "ADD" | "REMOVE", + ) { + patchedNodeSlots[node.id] = patchedNodeSlots[node.id] || {}; + const patchedNode = patchedNodeSlots[node.id]!; + if (ioDir == IoDirection.INPUT) { + patchedNode["inputs"] = patchedNode["inputs"] || {}; + // We can set to null (delete), so undefined means we haven't set it at all. + if (patchedNode["inputs"]![slot] !== undefined) { + !silent && + logger.log( + ` > Already set ${node.id}.inputs[${slot}] to ${patchedNode["inputs"]![ + slot + ]!} Skipping.`, + ); + return false; + } + let linkIdToSet = op === "REMOVE" ? null : linkId; + patchedNode["inputs"]![slot] = linkIdToSet; + if (fix) { + // node.inputs[slot]!.link = linkIdToSet; + } + } else { + patchedNode["outputs"] = patchedNode["outputs"] || {}; + patchedNode["outputs"]![slot] = patchedNode["outputs"]![slot] || { + links: [...(node.outputs?.[slot]?.links || [])], + changes: {}, + }; + if (patchedNode["outputs"]![slot]!["changes"]![linkId] !== undefined) { + !silent && + logger.log( + ` > Already set ${node.id}.outputs[${slot}] to ${ + patchedNode["inputs"]![slot] + }! Skipping.`, + ); + return false; + } + patchedNode["outputs"]![slot]!["changes"]![linkId] = op; + if (op === "ADD") { + let linkIdIndex = patchedNode["outputs"]![slot]!["links"].indexOf(linkId); + if (linkIdIndex !== -1) { + !silent && logger.log(` > Hmmm.. asked to add ${linkId} but it is already in list...`); + return false; + } + patchedNode["outputs"]![slot]!["links"].push(linkId); + if (fix) { + node.outputs = node.outputs || []; + node.outputs[slot] = node.outputs[slot] || ({} as any); + node.outputs[slot]!.links = node.outputs[slot]!.links || []; + node.outputs[slot]!.links!.push(linkId); + } + } else { + let linkIdIndex = patchedNode["outputs"]![slot]!["links"].indexOf(linkId); + if (linkIdIndex === -1) { + !silent && logger.log(` > Hmmm.. asked to remove ${linkId} but it doesn't exist...`); + return false; + } + patchedNode["outputs"]![slot]!["links"].splice(linkIdIndex, 1); + if (fix) { + node.outputs?.[slot]!.links!.splice(linkIdIndex, 1); + } + } + } + data.patchedNodes.push(node); + return true; + } + + /** + * Internal to check if a node (or patched data) has a linkId. + */ + function nodeHasLinkId( + node: SerializedNode | LGraphNode, + ioDir: IoDirection, + slot: number, + linkId: number, + ) { + // Patched data should be canonical. We can double check if fixing too. + let has = false; + if (ioDir === IoDirection.INPUT) { + let nodeHasIt = node.inputs?.[slot]?.link === linkId; + if (patchedNodeSlots[node.id]?.["inputs"]) { + let patchedHasIt = patchedNodeSlots[node.id]!["inputs"]![slot] === linkId; + // If we're fixing, double check that node matches. + if (fix && nodeHasIt !== patchedHasIt) { + throw Error("Error. Expected node to match patched data."); + } + has = patchedHasIt; + } else { + has = !!nodeHasIt; + } + } else { + let nodeHasIt = node.outputs?.[slot]?.links?.includes(linkId); + if (patchedNodeSlots[node.id]?.["outputs"]?.[slot]?.["changes"][linkId]) { + let patchedHasIt = patchedNodeSlots[node.id]!["outputs"]![slot]?.links.includes(linkId); + // If we're fixing, double check that node matches. + if (fix && nodeHasIt !== patchedHasIt) { + throw Error("Error. Expected node to match patched data."); + } + has = !!patchedHasIt; + } else { + has = !!nodeHasIt; + } + } + return has; + } + + /** + * Internal to check if a node (or patched data) has a linkId. 
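+   * (Unlike `nodeHasLinkId` above, this variant checks whether the slot has *any* link attached,
+   * regardless of the link id.)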
+ */ + function nodeHasAnyLink(node: SerializedNode | LGraphNode, ioDir: IoDirection, slot: number) { + // Patched data should be canonical. We can double check if fixing too. + let hasAny = false; + if (ioDir === IoDirection.INPUT) { + let nodeHasAny = node.inputs?.[slot]?.link != null; + if (patchedNodeSlots[node.id]?.["inputs"]) { + let patchedHasAny = patchedNodeSlots[node.id]!["inputs"]![slot] != null; + // If we're fixing, double check that node matches. + if (fix && nodeHasAny !== patchedHasAny) { + throw Error("Error. Expected node to match patched data."); + } + hasAny = patchedHasAny; + } else { + hasAny = !!nodeHasAny; + } + } else { + let nodeHasAny = node.outputs?.[slot]?.links?.length; + if (patchedNodeSlots[node.id]?.["outputs"]?.[slot]?.["changes"]) { + let patchedHasAny = patchedNodeSlots[node.id]!["outputs"]![slot]?.links.length; + // If we're fixing, double check that node matches. + if (fix && nodeHasAny !== patchedHasAny) { + throw Error("Error. Expected node to match patched data."); + } + hasAny = !!patchedHasAny; + } else { + hasAny = !!nodeHasAny; + } + } + return hasAny; + } + + let links: Array = []; + if (!Array.isArray(graph.links)) { + Object.values(graph.links).reduce((acc, v) => { + acc[v.id] = v; + return acc; + }, links); + } else { + links = graph.links; + } + + const linksReverse = [...links]; + linksReverse.reverse(); + for (let l of linksReverse) { + if (!l) continue; + const link = (l as LLink).origin_slot != null ? (l as LLink) : extendLink(l as SerializedLink); + + const originNode = getNodeById(graph, link.origin_id); + const originHasLink = () => + nodeHasLinkId(originNode!, IoDirection.OUTPUT, link.origin_slot, link.id); + const patchOrigin = (op: "ADD" | "REMOVE", id = link.id) => + patchNodeSlot(originNode!, IoDirection.OUTPUT, link.origin_slot, id, op); + + const targetNode = getNodeById(graph, link.target_id); + const targetHasLink = () => + nodeHasLinkId(targetNode!, IoDirection.INPUT, link.target_slot, link.id); + const targetHasAnyLink = () => nodeHasAnyLink(targetNode!, IoDirection.INPUT, link.target_slot); + const patchTarget = (op: "ADD" | "REMOVE", id = link.id) => + patchNodeSlot(targetNode!, IoDirection.INPUT, link.target_slot, id, op); + + const originLog = `origin(${link.origin_id}).outputs[${link.origin_slot}].links`; + const targetLog = `target(${link.target_id}).inputs[${link.target_slot}].link`; + + if (!originNode || !targetNode) { + if (!originNode && !targetNode) { + !silent && + logger.log( + `Link ${link.id} is invalid, ` + + `both origin ${link.origin_id} and target ${link.target_id} do not exist`, + ); + } else if (!originNode) { + !silent && + logger.log( + `Link ${link.id} is funky... ` + + `origin ${link.origin_id} does not exist, but target ${link.target_id} does.`, + ); + if (targetHasLink()) { + !silent && + logger.log( + ` > [PATCH] ${targetLog} does have link, will remove the inputs' link first.`, + ); + patchTarget("REMOVE", -1); + } + } else if (!targetNode) { + !silent && + logger.log( + `Link ${link.id} is funky... ` + + `target ${link.target_id} does not exist, but origin ${link.origin_id} does.`, + ); + if (originHasLink()) { + !silent && + logger.log(` > [PATCH] Origin's links' has ${link.id}; will remove the link first.`); + patchOrigin("REMOVE"); + } + } + continue; + } + + if (targetHasLink() || originHasLink()) { + if (!originHasLink()) { + !silent && + logger.log( + `${link.id} is funky... 
${originLog} does NOT contain it, but ${targetLog} does.`, + ); + !silent && + logger.log(` > [PATCH] Attempt a fix by adding this ${link.id} to ${originLog}.`); + patchOrigin("ADD"); + } else if (!targetHasLink()) { + !silent && + logger.log( + `${link.id} is funky... ${targetLog} is NOT correct (is ${targetNode.inputs?.[ + link.target_slot + ]?.link}), but ${originLog} contains it`, + ); + if (!targetHasAnyLink()) { + !silent && logger.log(` > [PATCH] ${targetLog} is not defined, will set to ${link.id}.`); + let patched = patchTarget("ADD"); + if (!patched) { + !silent && + logger.log( + ` > [PATCH] Nvm, ${targetLog} already patched. Removing ${link.id} from ${originLog}.`, + ); + patched = patchOrigin("REMOVE"); + } + } else { + !silent && + logger.log( + ` > [PATCH] ${targetLog} is defined, removing ${link.id} from ${originLog}.`, + ); + patchOrigin("REMOVE"); + } + } + } + } + + // Now that we've cleaned up the inputs, outputs, run through it looking for dangling links., + for (let l of linksReverse) { + if (!l) continue; + const link = (l as LLink).origin_slot != null ? (l as LLink) : extendLink(l as SerializedLink); + const originNode = getNodeById(graph, link.origin_id); + const targetNode = getNodeById(graph, link.target_id); + // Now that we've manipulated the linking, check again if they both exist. + if ( + (!originNode || !nodeHasLinkId(originNode, IoDirection.OUTPUT, link.origin_slot, link.id)) && + (!targetNode || !nodeHasLinkId(targetNode, IoDirection.INPUT, link.target_slot, link.id)) + ) { + !silent && + logger.log( + `${link.id} is def invalid; BOTH origin node ${link.origin_id} ${ + !originNode ? "is removed" : `doesn\'t have ${link.id}` + } and ${link.origin_id} target node ${ + !targetNode ? "is removed" : `doesn\'t have ${link.id}` + }.`, + ); + data.deletedLinks.push(link.id); + continue; + } + } + + // If we're fixing, then we've been patching along the way. Now go through and actually delete + // the zombie links from `app.graph.links` + if (fix) { + for (let i = data.deletedLinks.length - 1; i >= 0; i--) { + !silent && logger.log(`Deleting link #${data.deletedLinks[i]}.`); + if ((graph as LGraph).getNodeById) { + delete graph.links[data.deletedLinks[i]!]; + } else { + graph = graph as SerializedGraph; + // Sometimes we got objects for links if passed after ComfyUI's loadGraphData modifies the + // data. We make a copy now, but can handle the bastardized objects just in case. + const idx = graph.links.findIndex( + (l) => l && (l[0] === data.deletedLinks[i] || (l as any).id === data.deletedLinks[i]), + ); + if (idx === -1) { + logger.log(`INDEX NOT FOUND for #${data.deletedLinks[i]}`); + } + logger.log(`splicing ${idx} from links`); + graph.links.splice(idx, 1); + } + } + // If we're a serialized graph, we can filter out the links because it's just an array. + if (!(graph as LGraph).getNodeById) { + graph.links = (graph as SerializedGraph).links.filter((l) => !!l); + } + } + if (!data.patchedNodes.length && !data.deletedLinks.length) { + return { + hasBadLinks: false, + fixed: false, + graph, + patched: data.patchedNodes.length, + deleted: data.deletedLinks.length, + }; + } + !silent && + logger.log( + `${fix ? "Made" : "Would make"} ${data.patchedNodes.length || "no"} node link patches, and ${ + data.deletedLinks.length || "no" + } stale link removals.`, + ); + + let hasBadLinks: boolean = !!(data.patchedNodes.length || data.deletedLinks.length); + // If we're fixing, then let's run it again to see if there are no more bad links. 
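+  // (Callers can use the same two-step pattern externally; as a sketch, with the call site
+  // assumed: `fixBadLinks(graph, false, true)` is a silent dry run that reports `hasBadLinks`,
+  // and `fixBadLinks(graph, true)` then patches the slots and deletes stale links.)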
+ if (fix && !silent) { + const rerun = fixBadLinks(graph, false, true); + hasBadLinks = rerun.hasBadLinks; + } + + return { + hasBadLinks, + fixed: !!hasBadLinks && fix, + graph, + patched: data.patchedNodes.length, + deleted: data.deletedLinks.length, + }; +} diff --git a/rgthree-comfy/src_web/common/media/rgthree.svg b/rgthree-comfy/src_web/common/media/rgthree.svg new file mode 100644 index 0000000000000000000000000000000000000000..85e22fe25d0491149e95e48b1d02f7cd7cc42090 --- /dev/null +++ b/rgthree-comfy/src_web/common/media/rgthree.svg @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/rgthree-comfy/src_web/common/media/svgs.ts b/rgthree-comfy/src_web/common/media/svgs.ts new file mode 100644 index 0000000000000000000000000000000000000000..7f76c6c44f00f7cdc861b0549bad62ffc41071d5 --- /dev/null +++ b/rgthree-comfy/src_web/common/media/svgs.ts @@ -0,0 +1,186 @@ +import { createElement as $el } from "../utils_dom.js"; + +// Some svg repo : https://www.svgrepo.com/svg/326731/open-outline + +export const logoRgthree = ``; + +export const github = ``; + +export const iconStarFilled = ` + + `; + +export const iconReplace = ` + + + + + `; + +export const iconNode = ` + + + `; + +export const iconGear = ` + + `; + +export const checkmark = ` + + + + `; + +export const logoCivitai = ` + + + + + + + + + + `; + +export const iconOutLink = ` + + `; + +export const link = ` + + `; + +export const pencil = ` + + `; + +export const dotdotdot = ` + + + +`; + +export const models = ` + + + + +`; + +/** https://www.svgrepo.com/svg/402308/pencil */ +export const pencilColored = ` + + + + + + + + + `; + +/** https://www.svgrepo.com/svg/395640/save */ +export const diskColored = ` + + + + + + + +`; + +/** https://www.svgrepo.com/svg/229838/folder */ +export const folderColored = ` + + +`; + +export const modelsColored = ` + + + + +`; + +export const legoBlocksColored = ` + + + + + + + + + + + + +`; + +export const legoBlockColored = ` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +`; + +export const gearColored = ` + + + + +`; + +export function $svg(markup: string, attrs: { [key: string]: string }) { + if (!markup.match(/^\s* Promise> = new Map(); + + private handleWindowPointerDownBound = this.handleWindowPointerDown.bind(this); + + constructor(options: MenuOption[]) { + this.setOptions(options); + this.element.addEventListener('pointerup', async (e) => { + const target = getClosestOrSelf(e.target as HTMLElement, "[data-callback],menu"); + if (e.which !== 1) { + return; + } + const callback = target?.dataset?.['callback']; + if (callback) { + const halt = await this.callbacks.get(callback)?.(e); + if (halt !== false) { + this.close(); + } + } + e.preventDefault(); + e.stopPropagation(); + e.stopImmediatePropagation(); + }); + } + + setOptions(options: MenuOption[]) { + for (const option of options) { + if (option.type === 'title') { + this.element.appendChild($el(`li`, { + html: option.label + })); + } else { + const id = generateId(8); + this.callbacks.set(id, async (e: PointerEvent) => { return option?.callback?.(e); }); + this.element.appendChild($el(`li[role="button"][data-callback="${id}"]`, { + html: option.label + })); + } + } + } + + toElement() { + return this.element; + } + + async open(e: PointerEvent) { + const parent = (e.target as HTMLElement).closest('div,dialog,body') as HTMLElement + parent.appendChild(this.element); + setAttributes(this.element, { + style: { + left: `${e.clientX + 16}px`, + top: `${e.clientY - 16}px`, + } + }); + 
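+    // Open in two phases: first mark the menu as "measuring" (rendered but transparent per the
+    // CSS), measure its rect and flip it to the left of the cursor if it would overflow the
+    // right edge of the window, then flip the state to "open".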
this.element.setAttribute('state', 'measuring-open'); + await wait(16); + const rect = this.element.getBoundingClientRect(); + if (rect.right > window.innerWidth) { + this.element.style.left = `${e.clientX - rect.width - 16}px`; + await wait(16); + } + this.element.setAttribute('state', 'open'); + setTimeout(() => { + window.addEventListener('pointerdown', this.handleWindowPointerDownBound); + }); + } + + handleWindowPointerDown(e:PointerEvent) { + if (!this.element.contains(e.target as HTMLElement)) { + this.close(); + } + } + + async close() { + window.removeEventListener('pointerdown', this.handleWindowPointerDownBound); + this.element.setAttribute('state', 'measuring-closed'); + await wait(16); + this.element.setAttribute('state', 'closed'); + this.element.remove(); + } + + isOpen() { + return (this.element.getAttribute('state') || '').includes('open'); + } + +} + +type MenuOption = { + label: string; + type?: 'title'|'item'|'separator'; + callback?: (e: PointerEvent) => void; +} + +type MenuButtonOptions = { + icon: string; + options: MenuOption[]; +} + +export class MenuButton { + + private options: MenuButtonOptions; + private menu: Menu; + + private element: HTMLButtonElement = $el('button.rgthree-button[data-action="open-menu"]') + + constructor(options: MenuButtonOptions) { + this.options = options; + this.element.innerHTML = options.icon; + this.menu = new Menu(options.options); + + this.element.addEventListener('pointerdown', (e) => { + if (!this.menu.isOpen()) { + this.menu.open(e); + } + }); + } + + toElement() { + return this.element; + } + +} \ No newline at end of file diff --git a/rgthree-comfy/src_web/common/model_info_service.ts b/rgthree-comfy/src_web/common/model_info_service.ts new file mode 100644 index 0000000000000000000000000000000000000000..306297e98ee3d9dc1c52ddb2ba05c1bfaced65f5 --- /dev/null +++ b/rgthree-comfy/src_web/common/model_info_service.ts @@ -0,0 +1,74 @@ +import type { RgthreeModelInfo } from "typings/rgthree.js"; +import { rgthreeApi } from "./rgthree_api.js"; +import { api } from "scripts/api.js"; + +/** + * A singleton service to fetch and cache model infos from rgthree-comfy. + */ +class ModelInfoService extends EventTarget { + private readonly loraToInfo = new Map(); + + constructor() { + super(); + api.addEventListener( + "rgthree-refreshed-lora-info", + this.handleLoraAsyncUpdate.bind(this) as EventListener, + ); + } + + /** + * Single point to set data into the info cache, and fire an event. Note, this doesn't determine + * if the data is actually different. 
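+   *
+   * A minimal consumer sketch (the event name and detail shape match the dispatch below):
+   *
+   *     SERVICE.addEventListener("rgthree-model-service-lora-details", (e) => {
+   *       const lora = (e as CustomEvent<{ lora: RgthreeModelInfo }>).detail.lora;
+   *       console.log(lora.file);
+   *     });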
+ */ + private setFreshLoraData(file: string, info: RgthreeModelInfo) { + this.loraToInfo.set(file, info); + this.dispatchEvent( + new CustomEvent("rgthree-model-service-lora-details", { detail: { lora: info } }), + ); + } + + async getLora(file: string, refresh = false, light = false) { + if (this.loraToInfo.has(file) && !refresh) { + return this.loraToInfo.get(file)!; + } + return this.fetchLora(file, refresh, light); + } + + async fetchLora(file: string, refresh = false, light = false) { + let info = null; + if (!refresh) { + info = await rgthreeApi.getLorasInfo(file, light); + } else { + info = await rgthreeApi.refreshLorasInfo(file); + } + if (!light) { + this.loraToInfo.set(file, info); + } + return info; + } + + async refreshLora(file: string) { + return this.fetchLora(file, true); + } + + async clearLoraFetchedData(file: string) { + await rgthreeApi.clearLorasInfo(file); + this.loraToInfo.delete(file); + return null; + } + + async saveLoraPartial(file: string, data: Partial) { + let info = await rgthreeApi.saveLoraInfo(file, data); + this.loraToInfo.set(file, info); + return info; + } + + private handleLoraAsyncUpdate(event: CustomEvent<{ data: RgthreeModelInfo }>) { + const info = event.detail?.data as RgthreeModelInfo; + if (info?.file) { + this.setFreshLoraData(info.file, info); + } + } +} + +export const SERVICE = new ModelInfoService(); diff --git a/rgthree-comfy/src_web/common/progress_bar.ts b/rgthree-comfy/src_web/common/progress_bar.ts new file mode 100644 index 0000000000000000000000000000000000000000..15de18fac26262a842709ac2f49e6a8e17778b80 --- /dev/null +++ b/rgthree-comfy/src_web/common/progress_bar.ts @@ -0,0 +1,218 @@ +/** + * Progress bar web component. + */ + +import { SERVICE as PROMPT_SERVICE, type PromptExecution } from "rgthree/common/prompt_service.js"; +import { createElement } from "./utils_dom.js"; + +/** + * The progress bar web component. + */ +export class RgthreeProgressBar extends HTMLElement { + static NAME = "rgthree-progress-bar"; + + static create(): RgthreeProgressBar { + return document.createElement(RgthreeProgressBar.NAME) as RgthreeProgressBar; + } + + private shadow: ShadowRoot | null = null; + private progressNodesEl!: HTMLDivElement; + private progressStepsEl!: HTMLDivElement; + private progressTextEl!: HTMLSpanElement; + + private currentPromptExecution: PromptExecution | null = null; + + private readonly onProgressUpdateBound = this.onProgressUpdate.bind(this); + + private connected: boolean = false; + + /** The currentNodeId so outside callers can see what we're currently executing against. 
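+   *
+   * A usage sketch (the element must be attached to the DOM so it subscribes to progress
+   * updates):
+   *
+   *     const bar = RgthreeProgressBar.create();
+   *     document.body.appendChild(bar);
+   *     // ...later, e.g. from a polling loop:
+   *     const nodeId = bar.currentNodeId;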
*/ + get currentNodeId() { + const prompt = this.currentPromptExecution; + const nodeId = prompt?.errorDetails?.node_id || prompt?.currentlyExecuting?.nodeId; + return nodeId || null; + } + + constructor() { + super(); + } + + private onProgressUpdate(e: CustomEvent<{ queue: number; prompt: PromptExecution }>) { + if (!this.connected) return; + + const prompt = e.detail.prompt; + this.currentPromptExecution = prompt; + + if (prompt?.errorDetails) { + let progressText = `${prompt.errorDetails?.exception_type} ${ + prompt.errorDetails?.node_id || "" + } ${prompt.errorDetails?.node_type || ""}`; + this.progressTextEl.innerText = progressText; + this.progressNodesEl.classList.add("-error"); + this.progressStepsEl.classList.add("-error"); + return; + } + if (prompt?.currentlyExecuting) { + this.progressNodesEl.classList.remove("-error"); + this.progressStepsEl.classList.remove("-error"); + + const current = prompt?.currentlyExecuting; + + let progressText = `(${e.detail.queue}) `; + + // Sometimes we may get status updates for a workflow that was already running. In that case + // we don't know totalNodes. + if (!prompt.totalNodes) { + progressText += `??%`; + this.progressNodesEl.style.width = `0%`; + } else { + const percent = (prompt.executedNodeIds.length / prompt.totalNodes) * 100; + this.progressNodesEl.style.width = `${Math.max(2, percent)}%`; + // progressText += `Node ${prompt.executedNodeIds.length + 1} of ${prompt.totalNodes || "?"}`; + progressText += `${Math.round(percent)}%`; + } + + let nodeLabel = current.nodeLabel?.trim(); + let stepsLabel = ""; + if (current.step != null && current.maxSteps) { + const percent = (current.step / current.maxSteps) * 100; + this.progressStepsEl.style.width = `${percent}%`; + // stepsLabel += `Step ${current.step} of ${current.maxSteps}`; + if (current.pass > 1 || current.maxPasses != null) { + stepsLabel += `#${current.pass}`; + if (current.maxPasses && current.maxPasses > 0) { + stepsLabel += `/${current.maxPasses}`; + } + stepsLabel += ` - `; + } + stepsLabel += `${Math.round(percent)}%`; + } + + if (nodeLabel || stepsLabel) { + progressText += ` - ${nodeLabel || "???"}${stepsLabel ? ` (${stepsLabel})` : ""}`; + } + if (!stepsLabel) { + this.progressStepsEl.style.width = `0%`; + } + this.progressTextEl.innerText = progressText; + } else { + if (e?.detail.queue) { + this.progressTextEl.innerText = `(${e.detail.queue}) Running... in another tab`; + } else { + this.progressTextEl.innerText = "Idle"; + } + this.progressNodesEl.style.width = `0%`; + this.progressStepsEl.style.width = `0%`; + } + } + + connectedCallback() { + if (!this.connected) { + PROMPT_SERVICE.addEventListener( + "progress-update", + this.onProgressUpdateBound as EventListener, + ); + this.connected = true; + } + // We were already connected, so we just need to reset. 
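+    // (The shadow root survives a disconnect/reconnect cycle, so when it already exists we only
+    // reset the bars and label rather than rebuilding the DOM and stylesheet.)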
+ if (this.shadow) { + this.progressTextEl.innerText = "Idle"; + this.progressNodesEl.style.width = `0%`; + this.progressStepsEl.style.width = `0%`; + return; + } + + this.shadow = this.attachShadow({ mode: "open" }); + const sheet = new CSSStyleSheet(); + sheet.replaceSync(` + + :host { + position: relative; + overflow: hidden; + box-sizing: border-box; + background: var(--rgthree-progress-bg-color); + --rgthree-progress-bg-color: rgba(23, 23, 23, 0.9); + --rgthree-progress-nodes-bg-color: rgb(0, 128, 0); + --rgthree-progress-steps-bg-color: rgb(0, 128, 0); + --rgthree-progress-error-bg-color: rgb(128, 0, 0); + --rgthree-progress-text-color: #fff; + } + :host * { + box-sizing: inherit; + } + + :host > div.bar { + background: var(--rgthree-progress-nodes-bg-color); + position: absolute; + left: 0; + top: 0; + width: 0%; + height: 50%; + z-index: 1; + transition: width 50ms ease-in-out; + } + :host > div.bar + div.bar { + background: var(--rgthree-progress-steps-bg-color); + top: 50%; + height: 50%; + z-index: 2; + } + :host > div.bar.-error { + background: var(--rgthree-progress-error-bg-color); + } + + :host > .overlay { + position: absolute; + left: 0; + top: 0; + width: 100%; + height: 100%; + z-index: 5; + background: linear-gradient(to bottom, rgba(255,255,255,0.25), rgba(0,0,0,0.25)); + mix-blend-mode: overlay; + } + + :host > span { + position: relative; + z-index: 4; + text-align: left; + font-size: inherit; + height: 100%; + font-family: sans-serif; + text-shadow: 1px 1px 0px #000; + display: flex; + flex-direction: row; + padding: 0 6px; + align-items: center; + justify-content: start; + color: var(--rgthree-progress-text-color); + text-shadow: black 0px 0px 2px; + } + + :host > div.bar[style*="width: 0%"]:first-child, + :host > div.bar[style*="width:0%"]:first-child { + height: 0%; + } + :host > div.bar[style*="width: 0%"]:first-child + div, + :host > div.bar[style*="width:0%"]:first-child + div { + bottom: 0%; + } + `); + this.shadow.adoptedStyleSheets = [sheet]; + + const overlayEl = createElement(`div.overlay[part="overlay"]`, { parent: this.shadow }); + this.progressNodesEl = createElement(`div.bar[part="progress-nodes"]`, { parent: this.shadow }); + this.progressStepsEl = createElement(`div.bar[part="progress-steps"]`, { parent: this.shadow }); + this.progressTextEl = createElement(`span[part="text"]`, { text: "Idle", parent: this.shadow }); + } + + disconnectedCallback() { + this.connected = false; + PROMPT_SERVICE.removeEventListener( + "progress-update", + this.onProgressUpdateBound as EventListener, + ); + } +} + +customElements.define(RgthreeProgressBar.NAME, RgthreeProgressBar); diff --git a/rgthree-comfy/src_web/common/prompt_service.ts b/rgthree-comfy/src_web/common/prompt_service.ts new file mode 100644 index 0000000000000000000000000000000000000000..381f5bc68e4970e22f4d84b3fc3cde9e74abee99 --- /dev/null +++ b/rgthree-comfy/src_web/common/prompt_service.ts @@ -0,0 +1,274 @@ +import type { + ComfyApiEventDetailCached, + ComfyApiEventDetailError, + ComfyApiEventDetailExecuted, + ComfyApiEventDetailExecuting, + ComfyApiEventDetailExecutionStart, + ComfyApiEventDetailProgress, + ComfyApiEventDetailStatus, + ComfyApiFormat, + ComfyApiPrompt, +} from "typings/comfy.js"; +import { api } from "scripts/api.js"; +import type { LGraph as TLGraph, LGraphCanvas as TLGraphCanvas } from "typings/litegraph.js"; +import { Resolver, getResolver } from "./shared_utils.js"; + +/** + * Wraps general data of a prompt's execution. 
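+ *
+ * Instances are normally obtained from the PromptService below rather than constructed directly.
+ * A sketch, mirroring how the progress bar consumes them:
+ *
+ *     SERVICE.addEventListener("progress-update", (e) => {
+ *       const prompt = (e as CustomEvent).detail.prompt as PromptExecution | null;
+ *       console.log(prompt?.currentlyExecuting?.nodeLabel);
+ *     });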
+ */ +export class PromptExecution { + id: string; + promptApi: ComfyApiFormat | null = null; + executedNodeIds: string[] = []; + totalNodes: number = 0; + currentlyExecuting: { + nodeId: string; + nodeLabel?: string; + step?: number; + maxSteps?: number; + /** The current pass, for nodes with multiple progress passes. */ + pass: number; + /** + * The max num of passes. Can be calculated for some nodes, or set to -1 when known there will + * be multiple passes, but the number cannot be calculated. + */ + maxPasses?: number; + } | null = null; + errorDetails: any | null = null; + + apiPrompt: Resolver = getResolver(); + + constructor(id: string) { + this.id = id; + } + + /** + * Sets the prompt and prompt-related data. This can technically come in lazily, like if the web + * socket fires the 'execution-start' event before we actually get a response back from the + * initial prompt call. + */ + setPrompt(prompt: ComfyApiPrompt) { + this.promptApi = prompt.output; + this.totalNodes = Object.keys(this.promptApi).length; + this.apiPrompt.resolve(null); + } + + getApiNode(nodeId: string | number) { + return this.promptApi?.[String(nodeId)] || null; + } + + private getNodeLabel(nodeId: string | number) { + const apiNode = this.getApiNode(nodeId); + let label = apiNode?._meta?.title || apiNode?.class_type || undefined; + if (!label) { + const graphNode = this.maybeGetComfyGraph()?.getNodeById(Number(nodeId)); + label = graphNode?.title || graphNode?.type || undefined; + } + return label; + } + + /** + * Updates the execution data depending on the passed data, fed from api events. + */ + executing(nodeId: string | null, step?: number, maxSteps?: number) { + if (nodeId == null) { + // We're done, any left over nodes must be skipped... + this.currentlyExecuting = null; + return; + } + if (this.currentlyExecuting?.nodeId !== nodeId) { + if (this.currentlyExecuting != null) { + this.executedNodeIds.push(nodeId); + } + this.currentlyExecuting = { nodeId, nodeLabel: this.getNodeLabel(nodeId), pass: 0 }; + // We'll see if we're known node for multiple passes, that will come in as generic 'progress' + // updates from the api. If we're known to have multiple passes, then we'll pre-set data to + // allow the progress bar to handle intial rendering. If we're not, that's OK, the data will + // be shown with the second pass. + this.apiPrompt.promise.then(() => { + // If we execute with a null node id and clear the currently executing, then we can just + // move on. This seems to only happen with a super-fast execution (like, just seed node + // and display any for testing). + if (this.currentlyExecuting == null) { + return; + } + const apiNode = this.getApiNode(nodeId); + if (!this.currentlyExecuting.nodeLabel) { + this.currentlyExecuting.nodeLabel = this.getNodeLabel(nodeId); + } + if (apiNode?.class_type === "UltimateSDUpscale") { + // From what I can tell, UltimateSDUpscale, does an initial pass that isn't actually a + // tile. It seems to always be 4 steps... We'll start our pass at -1, so this prepass is + // "0" and "1" will start with the first tile. This way, a user knows they have 4 tiles, + // know this pass counter will go to 4 (and not 5). Also, we cannot calculate maxPasses + // for 'UltimateSDUpscale' :( + this.currentlyExecuting.pass--; + this.currentlyExecuting.maxPasses = -1; + } else if (apiNode?.class_type === "IterativeImageUpscale") { + this.currentlyExecuting.maxPasses = (apiNode?.inputs["steps"] as number) ?? 
-1; + } + }); + } + if (step != null) { + // If we haven't had any stpes before, or the passes step is lower than the previous, then + // increase the passes. + if (!this.currentlyExecuting!.step || step < this.currentlyExecuting!.step) { + this.currentlyExecuting!.pass!++; + } + this.currentlyExecuting!.step = step; + this.currentlyExecuting!.maxSteps = maxSteps; + } + } + + /** + * If there's an error, we add the details. + */ + error(details: any) { + this.errorDetails = details; + } + + private maybeGetComfyGraph(): TLGraph | null { + return ((window as any)?.app?.graph as TLGraph) || null; + } +} + +/** + * A singleton service that wraps the Comfy API and simplifies the event data being fired. + */ +class PromptService extends EventTarget { + promptsMap: Map = new Map(); + currentExecution: PromptExecution | null = null; + lastQueueRemaining = 0; + + constructor(api: any) { + super(); + const that = this; + + // Patch the queuePrompt method so we can capture new data going through. + const queuePrompt = api.queuePrompt; + api.queuePrompt = async function (num: number, prompt: ComfyApiPrompt) { + let response; + try { + response = await queuePrompt.apply(api, [...arguments]); + } catch (e) { + const promptExecution = that.getOrMakePrompt("error"); + promptExecution.error({ exception_type: "Unknown." }); + // console.log("ERROR QUEUE PROMPT", response, arguments); + throw e; + } + // console.log("QUEUE PROMPT", response, arguments); + const promptExecution = that.getOrMakePrompt(response.prompt_id); + promptExecution.setPrompt(prompt); + if (!that.currentExecution) { + that.currentExecution = promptExecution; + } + that.promptsMap.set(response.prompt_id, promptExecution); + that.dispatchEvent( + new CustomEvent("queue-prompt", { + detail: { + prompt: promptExecution, + }, + }), + ); + return response; + }; + + api.addEventListener("status", (e: CustomEvent) => { + // console.log("status", JSON.stringify(e.detail)); + // Sometimes a status message is fired when the app loades w/o any details. 
+ if (!e.detail?.exec_info) return; + this.lastQueueRemaining = e.detail.exec_info.queue_remaining; + this.dispatchProgressUpdate(); + }); + + api.addEventListener("execution_start", (e: CustomEvent) => { + // console.log("execution_start", JSON.stringify(e.detail)); + if (!this.promptsMap.has(e.detail.prompt_id)) { + console.warn("'execution_start' fired before prompt was made."); + } + const prompt = this.getOrMakePrompt(e.detail.prompt_id); + this.currentExecution = prompt; + this.dispatchProgressUpdate(); + }); + + api.addEventListener("executing", (e: CustomEvent) => { + // console.log("executing", JSON.stringify(e.detail)); + if (!this.currentExecution) { + this.currentExecution = this.getOrMakePrompt("unknown"); + console.warn("'executing' fired before prompt was made."); + } + this.currentExecution.executing(e.detail); + this.dispatchProgressUpdate(); + if (e.detail == null) { + this.currentExecution = null; + } + }); + + api.addEventListener("progress", (e: CustomEvent) => { + // console.log("progress", JSON.stringify(e.detail)); + if (!this.currentExecution) { + this.currentExecution = this.getOrMakePrompt(e.detail.prompt_id); + console.warn("'progress' fired before prompt was made."); + } + this.currentExecution.executing(e.detail.node, e.detail.value, e.detail.max); + this.dispatchProgressUpdate(); + }); + + api.addEventListener("execution_cached", (e: CustomEvent) => { + // console.log("execution_cached", JSON.stringify(e.detail)); + if (!this.currentExecution) { + this.currentExecution = this.getOrMakePrompt(e.detail.prompt_id); + console.warn("'execution_cached' fired before prompt was made."); + } + for (const cached of e.detail.nodes) { + this.currentExecution.executing(cached); + } + this.dispatchProgressUpdate(); + }); + + api.addEventListener("executed", (e: CustomEvent) => { + // console.log("executed", JSON.stringify(e.detail)); + if (!this.currentExecution) { + this.currentExecution = this.getOrMakePrompt(e.detail.prompt_id); + console.warn("'executed' fired before prompt was made."); + } + }); + + api.addEventListener("execution_error", (e: CustomEvent) => { + // console.log("execution_error", e.detail); + if (!this.currentExecution) { + this.currentExecution = this.getOrMakePrompt(e.detail.prompt_id); + console.warn("'execution_error' fired before prompt was made."); + } + this.currentExecution?.error(e.detail); + this.dispatchProgressUpdate(); + }); + } + + /** A helper method, since we extend/override api.queuePrompt above anyway. 
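+   * e.g. `await SERVICE.queuePrompt(prompt)`, where `prompt` has the same shape as the payload
+   * ComfyUI's `api.queuePrompt` normally receives (a sketch).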
*/ + async queuePrompt(prompt: ComfyApiPrompt) { + return await api.queuePrompt(-1, prompt); + } + + dispatchProgressUpdate() { + this.dispatchEvent( + new CustomEvent("progress-update", { + detail: { + queue: this.lastQueueRemaining, + prompt: this.currentExecution, + }, + }), + ); + } + + getOrMakePrompt(id: string) { + let prompt = this.promptsMap.get(id); + if (!prompt) { + prompt = new PromptExecution(id); + this.promptsMap.set(id, prompt); + } + return prompt; + } +} + +export const SERVICE = new PromptService(api); diff --git a/rgthree-comfy/src_web/common/rgthree_api.ts b/rgthree-comfy/src_web/common/rgthree_api.ts new file mode 100644 index 0000000000000000000000000000000000000000..c9b59b849dffb549942c002dfc8036b0b67db5a1 --- /dev/null +++ b/rgthree-comfy/src_web/common/rgthree_api.ts @@ -0,0 +1,99 @@ +import type { RgthreeModelInfo } from "typings/rgthree.js"; + +class RgthreeApi { + private baseUrl: string; + getCheckpointsPromise: Promise | null = null; + getSamplersPromise: Promise | null = null; + getSchedulersPromise: Promise | null = null; + getLorasPromise: Promise | null = null; + getWorkflowsPromise: Promise | null = null; + + constructor(baseUrl?: string) { + this.baseUrl = baseUrl || "./rgthree/api"; + } + + apiURL(route: string) { + return `${this.baseUrl}${route}`; + } + + fetchApi(route: string, options?: RequestInit) { + return fetch(this.apiURL(route), options); + } + + async fetchJson(route: string, options?: RequestInit) { + const r = await this.fetchApi(route, options); + return await r.json(); + } + + async postJson(route: string, json: any) { + const body = new FormData(); + body.append("json", JSON.stringify(json)); + return await rgthreeApi.fetchJson(route, { method: "POST", body }); + } + + getLoras(force = false) { + if (!this.getLorasPromise || force) { + this.getLorasPromise = this.fetchJson("/loras", { cache: "no-store" }); + } + return this.getLorasPromise; + } + + async fetchApiJsonOrNull(route: string, options?: RequestInit) { + const response = await this.fetchJson(route, options); + if (response.status === 200 && response.data) { + return (response.data as T) || null; + } + return null; + } + + /** + * Fetches the lora information. + * + * @param light Whether or not to generate a json file if there isn't one. This isn't necessary if + * we're just checking for values, but is more necessary when opening an info dialog. + */ + async getLorasInfo(lora: string, light?: boolean): Promise; + async getLorasInfo(light?: boolean): Promise; + async getLorasInfo(...args: any) { + const params = new URLSearchParams(); + const isSingleLora = typeof args[0] == 'string'; + if (isSingleLora) { + params.set("file", args[0]); + } + params.set("light", (isSingleLora ? args[1] : args[0]) === false ? '0' : '1'); + const path = `/loras/info?` + params.toString(); + return await this.fetchApiJsonOrNull(path); + } + + async refreshLorasInfo(file: string): Promise; + async refreshLorasInfo(): Promise; + async refreshLorasInfo(file?: string) { + const path = `/loras/info/refresh` + (file ? `?file=${encodeURIComponent(file)}` : ''); + const infos = await this.fetchApiJsonOrNull(path); + return infos; + } + + async clearLorasInfo(file?: string): Promise { + const path = `/loras/info/clear` + (file ? `?file=${encodeURIComponent(file)}` : ''); + await this.fetchApiJsonOrNull(path); + return; + } + + /** + * Saves partial data sending it to the backend.. 
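+   *
+   * e.g. (a sketch; the file path and the edited field are assumptions):
+   *
+   *     await rgthreeApi.saveLoraInfo("some/lora.safetensors", { userNote: "trigger word: xyz" });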
+ */ + async saveLoraInfo( + lora: string, + data: Partial, + ): Promise { + const body = new FormData(); + body.append("json", JSON.stringify(data)); + return await this.fetchApiJsonOrNull( + `/loras/info?file=${encodeURIComponent(lora)}`, + { cache: "no-store", method: "POST", body }, + ); + } + +} + +export const rgthreeApi = new RgthreeApi(); diff --git a/rgthree-comfy/src_web/common/shared_utils.ts b/rgthree-comfy/src_web/common/shared_utils.ts new file mode 100644 index 0000000000000000000000000000000000000000..1e3ed0279b287a50232d094a999fd8650dc6e551 --- /dev/null +++ b/rgthree-comfy/src_web/common/shared_utils.ts @@ -0,0 +1,181 @@ +/** + * @fileoverview + * A bunch of shared utils that can be used in ComfyUI, as well as in any single-HTML pages. + */ + +export type Resolver = { + id: string; + completed: boolean; + resolved: boolean; + rejected: boolean; + promise: Promise; + resolve: (data: T) => void; + reject: () => void; + timeout: number | null; + // A caller property to store a defer timeout on. + deferredTimeout?: number | null; + deferredData?: any; +}; + +/** + * Returns a new `Resolver` type that allows creating a "disconnected" `Promise` that can be + * returned and resolved separately. + */ +export function getResolver(timeout: number = 5000): Resolver { + const resolver: Partial> = {}; + resolver.id = generateId(8); + resolver.completed = false; + resolver.resolved = false; + resolver.rejected = false; + resolver.promise = new Promise((resolve, reject) => { + resolver.reject = () => { + resolver.completed = true; + resolver.rejected = true; + reject(); + }; + resolver.resolve = (data: T) => { + resolver.completed = true; + resolver.resolved = true; + resolve(data); + }; + }); + resolver.timeout = setTimeout(() => { + if (!resolver.completed) { + resolver.reject!(); + } + }, timeout); + return resolver as Resolver; +} + +/** The WeakMap for debounced functions. */ +const DEBOUNCE_FN_TO_PROMISE: WeakMap> = new WeakMap(); + +/** + * Debounces a function call so it is only called once in the initially provided ms even if asked + * to be called multiple times within that period. + */ +export function debounce(fn: Function, ms = 64) { + if (!DEBOUNCE_FN_TO_PROMISE.get(fn)) { + DEBOUNCE_FN_TO_PROMISE.set( + fn, + wait(ms).then(() => { + DEBOUNCE_FN_TO_PROMISE.delete(fn); + fn(); + }), + ); + } + return DEBOUNCE_FN_TO_PROMISE.get(fn); +} + +/** Waits a certain number of ms, as a `Promise.` */ +export function wait(ms = 16): Promise { + // Special logic, if we're waiting 16ms, then trigger on next frame. + if (ms === 16) { + return new Promise((resolve) => { + requestAnimationFrame(() => { + resolve(); + }); + }); + } + return new Promise((resolve) => { + setTimeout(() => { + resolve(); + }, ms); + }); +} + +function dec2hex(dec: number) { + return dec.toString(16).padStart(2, "0"); +} + +/** Generates an unique id of a specific length. */ +export function generateId(length: number) { + const arr = new Uint8Array(length / 2); + crypto.getRandomValues(arr); + return Array.from(arr, dec2hex).join(""); +} + +/** + * Returns the deep value of an object given a dot-delimited key. + */ +export function getObjectValue(obj: any, objKey: string, def?: any) { + if (!obj || !objKey) return def; + + const keys = objKey.split("."); + const key = keys.shift()!; + const found = obj[key]; + if (keys.length) { + return getObjectValue(found, keys.join("."), def); + } + return found; +} + +/** + * Sets the deep value of an object given a dot-delimited key. 
+ * + * By default, missing objects will be created while settng the path. If `createMissingObjects` is + * set to false, then the setting will be abandoned if the key path is missing an intermediate + * value. For example: + * + * setObjectValue({a: {z: false}}, 'a.b.c', true); // {a: {z: false, b: {c: true } } } + * setObjectValue({a: {z: false}}, 'a.b.c', true, false); // {a: {z: false}} + * + */ +export function setObjectValue(obj: any, objKey: string, value: any, createMissingObjects = true) { + if (!obj || !objKey) return obj; + + const keys = objKey.split("."); + const key = keys.shift()!; + if (obj[key] === undefined) { + if (!createMissingObjects) { + return; + } + obj[key] = {}; + } + if (!keys.length) { + obj[key] = value; + } else { + if (typeof obj[key] != "object") { + obj[key] = {}; + } + setObjectValue(obj[key], keys.join("."), value, createMissingObjects); + } + return obj; +} + +/** + * Moves an item in an array (by item or its index) to another index. + */ +export function moveArrayItem(arr: T[], itemOrFrom: T | number, to: number) { + const from = typeof itemOrFrom === "number" ? itemOrFrom : arr.indexOf(itemOrFrom); + arr.splice(to, 0, arr.splice(from, 1)[0]!); +} + +/** + * Moves an item in an array (by item or its index) to another index. + */ +export function removeArrayItem(arr: T[], itemOrIndex: T | number) { + const index = typeof itemOrIndex === "number" ? itemOrIndex : arr.indexOf(itemOrIndex); + arr.splice(index, 1); +} + +/** + * Injects CSS into the page with a promise when complete. + */ +export function injectCss(href: string): Promise { + if (document.querySelector(`link[href^="${href}"]`)) { + return Promise.resolve(); + } + return new Promise((resolve) => { + const link = document.createElement("link"); + link.setAttribute("rel", "stylesheet"); + link.setAttribute("type", "text/css"); + const timeout = setTimeout(resolve, 1000); + link.addEventListener("load", (e) => { + clearInterval(timeout); + resolve(); + }); + link.href = href; + document.head.appendChild(link); + }); +} diff --git a/rgthree-comfy/src_web/common/utils_dom.ts b/rgthree-comfy/src_web/common/utils_dom.ts new file mode 100644 index 0000000000000000000000000000000000000000..5e52cc98abdf0ab078666f4a428da58c00a3add9 --- /dev/null +++ b/rgthree-comfy/src_web/common/utils_dom.ts @@ -0,0 +1,350 @@ +/** + * Various dom manipulation utils that have followed me around. 
+ */ +const DIRECT_ATTRIBUTE_MAP: {[name: string]: string} = { + cellpadding: 'cellPadding', + cellspacing: 'cellSpacing', + colspan: 'colSpan', + frameborder: 'frameBorder', + height: 'height', + maxlength: 'maxLength', + nonce: 'nonce', + role: 'role', + rowspan: 'rowSpan', + type: 'type', + usemap: 'useMap', + valign: 'vAlign', + width: 'width', +}; + +const RGX_NUMERIC_STYLE_UNIT = 'px'; +const RGX_NUMERIC_STYLE = /^((max|min)?(width|height)|margin|padding|(margin|padding)?(left|top|bottom|right)|fontsize|borderwidth)$/i; +const RGX_DEFAULT_VALUE_PROP = /input|textarea|select/i; + + +function localAssertNotFalsy(input?: T|null, errorMsg = `Input is not of type.`) : T { + if (input == null) { + throw new Error(errorMsg); + } + return input; +} + + +const RGX_STRING_VALID = '[a-z0-9_-]'; +const RGX_TAG = new RegExp(`^([a-z]${RGX_STRING_VALID}*)(\\.|\\[|\\#|$)`, 'i'); +const RGX_ATTR_ID = new RegExp(`#(${RGX_STRING_VALID}+)`, 'gi'); +const RGX_ATTR_CLASS = new RegExp(`(^|\\S)\\.([a-z0-9_\\-\\.]+)`, 'gi'); +const RGX_STRING_CONTENT_TO_SQUARES = '(.*?)(\\[|\\])'; +const RGX_ATTRS_MAYBE_OPEN = new RegExp(`\\[${RGX_STRING_CONTENT_TO_SQUARES}`, 'gi'); +const RGX_ATTRS_FOLLOW_OPEN = new RegExp(`^${RGX_STRING_CONTENT_TO_SQUARES}`, 'gi'); + +export function query(selectors: K, parent?: HTMLElement|Document): Array; +export function query(selectors: K, parent?: HTMLElement|Document): Array; +export function query(selectors: K, parent?: HTMLElement|Document): Array; +export function query(selectors: string, parent?: HTMLElement|Document): Array; +export function query(selectors: string, parent: HTMLElement|Document = document) { + return Array.from(parent.querySelectorAll(selectors)).filter(n => !!n); +} + +export function queryOne(selectors: K, parent?: HTMLElement|Document): HTMLElementTagNameMap[K] | null; +export function queryOne(selectors: K, parent?: HTMLElement|Document): SVGElementTagNameMap[K] | null; +export function queryOne(selectors: K, parent?: HTMLElement|Document): MathMLElementTagNameMap[K] | null; +export function queryOne(selectors: string, parent?: HTMLElement|Document): T | null; +export function queryOne(selectors: string, parent: HTMLElement|Document = document) { + return parent.querySelector(selectors) ?? 
null; +} + +export function createText(text: string) { + return document.createTextNode(text); +} + +export function getClosestOrSelf(element: EventTarget|HTMLElement|null, query: string) : HTMLElement|null { + const el = (element as HTMLElement); + return (el?.closest && (el.matches(query) && el || el.closest(query)) as HTMLElement) || null; +} + +type Attrs = { + [name: string]: any; +}; + +export function createElement(selectorOrMarkup: string, attrs?: Attrs) { + const frag = getHtmlFragment(selectorOrMarkup); + let element = frag?.firstElementChild as HTMLElement; + let selector = ""; + if (!element) { + selector = selectorOrMarkup.replace(/[\r\n]\s*/g, ""); + const tag = getSelectorTag(selector) || "div"; + element = document.createElement(tag); + selector = selector.replace(RGX_TAG, "$2"); + // Turn id and classname into [attr]s that can be nested + selector = selector.replace(RGX_ATTR_ID, '[id="$1"]'); + selector = selector.replace( + RGX_ATTR_CLASS, + (match, p1, p2) => `${p1}[class="${p2.replace(/\./g, " ")}"]`, + ); + } + + const selectorAttrs = getSelectorAttributes(selector); + if (selectorAttrs) { + for (const attr of selectorAttrs) { + let matches = attr.substring(1, attr.length - 1).split("="); + let key = localAssertNotFalsy(matches.shift()); + let value: string = matches.join("="); + if (value === undefined) { + setAttribute(element, key, true); + } else { + value = value.replace(/^['"](.*)['"]$/, "$1"); + setAttribute(element, key, value); + } + } + } + if (attrs) { + setAttributes(element, attrs); + } + return element as T; +} + +function getSelectorTag(str: string) { + return tryMatch(str, RGX_TAG); +} + +function getSelectorAttributes(selector: string) { + RGX_ATTRS_MAYBE_OPEN.lastIndex = 0; + let attrs: string[] = []; + let result; + while (result = RGX_ATTRS_MAYBE_OPEN.exec(selector)) { + let attr = result[0]; + if (attr.endsWith(']')) { + attrs.push(attr); + } else { + attr = result[0] + + getOpenAttributesRecursive(selector.substr(RGX_ATTRS_MAYBE_OPEN.lastIndex), 2); + RGX_ATTRS_MAYBE_OPEN.lastIndex += (attr.length - result[0].length); + attrs.push(attr); + } + } + return attrs; +} + + +function getOpenAttributesRecursive(selectorSubstring: string, openCount: number) { + let matches = selectorSubstring.match(RGX_ATTRS_FOLLOW_OPEN); + let result = ''; + if (matches && matches.length) { + result = matches[0]; + openCount += result.endsWith(']') ? 
-1 : 1; + if (openCount > 0) { + result += getOpenAttributesRecursive(selectorSubstring.substr(result.length), openCount); + } + } + return result; +} + +function tryMatch(str: string, rgx: RegExp, index = 1) { + let found = ''; + try { + found = str.match(rgx)?.[index] || ''; + } catch (e) { + found = ''; + } + return found; +} + +export function setAttributes(element: HTMLElement, data: {[name: string]: any}) { + let attr; + for (attr in data) { + if (data.hasOwnProperty(attr)) { + setAttribute(element, attr, data[attr]); + } + } +} + +function getHtmlFragment(value: string) { + if (value.match(/^\s*<.*?>[\s\S]*<\/[a-z0-9]+>\s*$/)) { + return document.createRange().createContextualFragment(value.trim()); + } + return null; +} + +function getChild(value: any) : HTMLElement|DocumentFragment|Text|null { + if (value instanceof Node) { + return value as HTMLElement; + } + if (typeof value === 'string') { + let child = getHtmlFragment(value); + if (child) { + return child; + } + if (getSelectorTag(value)) { + return createElement(value); + } + return createText(value); + } + if (value && typeof value.toElement === 'function') { + return value.toElement() as HTMLElement; + } + return null; +} + + +export function setAttribute(element: HTMLElement, attribute: string, value: any) { + let isRemoving = value == null; + + if (attribute === 'default') { + attribute = RGX_DEFAULT_VALUE_PROP.test(element.nodeName) ? 'value' : 'text'; + } + + if (attribute === 'text') { + empty(element).appendChild(createText(value != null ? String(value) : '')); + + } else if (attribute === 'html') { + empty(element).innerHTML += value != null ? String(value) : ''; + + } else if (attribute == 'style') { + if (typeof value === 'string') { + element.style.cssText = isRemoving ? '' : (value != null ? String(value) : ''); + } else { + for (const [styleKey, styleValue] of Object.entries(value as {[key: string]: any})) { + element.style[styleKey as 'display'] = styleValue; + } + } + + } else if (attribute == 'events') { + for (const [key, fn] of Object.entries(value as {[key: string]: (e: Event) => void})) { + addEvent(element, key, fn); + } + + } else if (attribute === 'parent') { + value.appendChild(element); + + } else if (attribute === 'child' || attribute === 'children') { + // Try to handle an array, like [li,li,li]. Not nested brackets, though + if (typeof value === 'string' && /^\[[^\[\]]+\]$/.test(value)) { + const parseable = value.replace(/^\[([^\[\]]+)\]$/, '["$1"]').replace(/,/g, '","'); + try { + const parsed = JSON.parse(parseable); + value = parsed; + } catch(e) { + console.error(e); + } + } + + // "children" is a replace of the children, while "child" appends a new child if others exist. + if (attribute === 'children') { + empty(element); + } + + let children = value instanceof Array ? value : [value]; + for (let child of children) { + child = getChild(child); + if (child instanceof Node) { + if (element instanceof HTMLTemplateElement) { + element.content.appendChild(child); + } else { + element.appendChild(child); + } + } + } + + } else if (attribute == 'for') { + (element as HTMLLabelElement).htmlFor = value != null ? String(value) : ''; + if (isRemoving) { + // delete (element as HTMLLabelElement).htmlFor; + element.removeAttribute('for'); + } + + } else if (attribute === 'class' || attribute === 'className' || attribute === 'classes') { + element.className = isRemoving ? '' : Array.isArray(value) ? 
value.join(' ') : String(value); + + } else if (attribute === 'dataset') { + if (typeof value !== 'object') { + console.error('Expecting an object for dataset'); + return; + } + for (const [key, val] of Object.entries(value)) { + element.dataset[key] = String(val); + } + + } else if (attribute == 'onclick' && typeof value === 'function') { + element.addEventListener('click', value); + + } else if (['checked', 'disabled', 'readonly', 'required', 'selected'].includes(attribute)) { + // Could be input, button, etc. We are not discriminate. + (element as HTMLInputElement)[attribute as 'checked'] = !!value; + if (!value) { + (element as HTMLInputElement).removeAttribute(attribute); + } else { + (element as HTMLInputElement).setAttribute(attribute, attribute); + } + + } else if (DIRECT_ATTRIBUTE_MAP.hasOwnProperty(attribute)) { + if (isRemoving) { + element.removeAttribute(DIRECT_ATTRIBUTE_MAP[attribute]!); + } else { + element.setAttribute(DIRECT_ATTRIBUTE_MAP[attribute]!, String(value)); + } + + } else if (isRemoving) { + element.removeAttribute(attribute); + + } else { + let oldVal = element.getAttribute(attribute); + if (oldVal !== value) { + element.setAttribute(attribute, String(value)); + } + } +} + +function addEvent(element: HTMLElement, key: string, fn: (e:Event) => void) { + element.addEventListener(key, fn); +} + +function setStyles(element: HTMLElement, styles: {[name: string]: string|number}|null = null) { + if (styles) { + for (let name in styles) { + setStyle(element, name, styles[name]!); + } + } + return element; +} + +function setStyle(element: HTMLElement, name: string, value: string|number|null) { + // Note: Old IE uses 'styleFloat' + name = (name.indexOf('float') > -1 ? 'cssFloat' : name); + // Camelcase + if (name.indexOf('-') != -1) { + name = name.replace(/-\D/g, (match) => { + return match.charAt(1).toUpperCase(); + }); + } + if (value == String(Number(value)) && RGX_NUMERIC_STYLE.test(name)) { + value = value + RGX_NUMERIC_STYLE_UNIT; + } + if (name === 'display' && typeof value !== 'string') { + value = !!value ? null : 'none'; + } + (element.style as any)[name] = value === null ? null : String(value); + return element; +}; + +export function empty(element: HTMLElement) { + while (element.firstChild) { + element.removeChild(element.firstChild); + } + return element; +} + +type ChildType = HTMLElement|DocumentFragment|Text|string|null; +export function appendChildren(el: HTMLElement, children: ChildType|ChildType[]) { + children = !Array.isArray(children) ? [children] : children; + for (let child of children) { + child = getChild(child); + if (child instanceof Node) { + if (el instanceof HTMLTemplateElement) { + el.content.appendChild(child); + } else { + el.appendChild(child); + } + } + } +} diff --git a/rgthree-comfy/src_web/common/utils_workflow.ts b/rgthree-comfy/src_web/common/utils_workflow.ts new file mode 100644 index 0000000000000000000000000000000000000000..5967bcb10dec50e2df001cc351d5d1adba0f1303 --- /dev/null +++ b/rgthree-comfy/src_web/common/utils_workflow.ts @@ -0,0 +1,71 @@ +import { getResolver } from "./shared_utils.js"; +import { getPngMetadata, getWebpMetadata } from "scripts/pnginfo.js"; +import type { SerializedGraph } from "typings/index.js"; +import type { ComfyApiFormat } from "typings/comfy.js"; + +/** + * Parses the workflow JSON and do any necessary cleanup. 
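 *
 * (Editor's illustration, not part of this diff.) The main cleanup is replacing bare `NaN`
 * values, which `JSON.parse` rejects, with `null` before parsing, e.g.:
 *
 *   parseWorkflowJson('{"is_changed": NaN}');  // => { is_changed: null }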
+ */ +function parseWorkflowJson(stringJson?: string) { + stringJson = stringJson || "null"; + // Starting around August 2024 the serialized JSON started to get messy and contained `NaN` (for + // an is_changed property, specifically). NaN is not parseable, so we'll get those on out of there + // and cleanup anything else we need. + stringJson = stringJson.replace(/:\s*NaN/g, ": null"); + return JSON.parse(stringJson); +} + +export async function tryToGetWorkflowDataFromEvent( + e: DragEvent, +): Promise<{ workflow: SerializedGraph | null; prompt: ComfyApiFormat | null }> { + let work; + for (const file of e.dataTransfer?.files || []) { + const data = await tryToGetWorkflowDataFromFile(file); + if (data.workflow || data.prompt) { + return data; + } + } + const validTypes = ["text/uri-list", "text/x-moz-url"]; + const match = (e.dataTransfer?.types || []).find((t) => validTypes.find((v) => t === v)); + if (match) { + const uri = e.dataTransfer!.getData(match)?.split("\n")?.[0]; + if (uri) { + return tryToGetWorkflowDataFromFile(await (await fetch(uri)).blob()); + } + } + return { workflow: null, prompt: null }; +} + +export async function tryToGetWorkflowDataFromFile( + file: File | Blob, +): Promise<{ workflow: SerializedGraph | null; prompt: ComfyApiFormat | null }> { + if (file.type === "image/png") { + const pngInfo = await getPngMetadata(file); + return { + workflow: parseWorkflowJson(pngInfo?.workflow), + prompt: parseWorkflowJson(pngInfo?.prompt), + }; + } + + if (file.type === "image/webp") { + const pngInfo = await getWebpMetadata(file); + // Support loading workflows from that webp custom node. + const workflow = parseWorkflowJson(pngInfo?.workflow || pngInfo?.Workflow || "null"); + const prompt = parseWorkflowJson(pngInfo?.prompt || pngInfo?.Prompt || "null"); + return { workflow, prompt }; + } + + if (file.type === "application/json" || (file as File).name?.endsWith(".json")) { + const resolver = getResolver<{ workflow: any; prompt: any }>(); + const reader = new FileReader(); + reader.onload = async () => { + const json = parseWorkflowJson(reader.result as string); + const isApiJson = Object.values(json).every((v: any) => v.class_type); + const prompt = isApiJson ? json : null; + const workflow = !isApiJson && !json?.templates ? json : null; + return { workflow, prompt }; + }; + return resolver.promise; + } + return { workflow: null, prompt: null }; +} diff --git a/rgthree-comfy/src_web/link_fixer/icon_file_json.png b/rgthree-comfy/src_web/link_fixer/icon_file_json.png new file mode 100644 index 0000000000000000000000000000000000000000..ad3a1cb2b89a2051010d53d69b12cca8735af353 Binary files /dev/null and b/rgthree-comfy/src_web/link_fixer/icon_file_json.png differ diff --git a/rgthree-comfy/src_web/link_fixer/index.html b/rgthree-comfy/src_web/link_fixer/index.html new file mode 100644 index 0000000000000000000000000000000000000000..e998f8e07a8e8787cd2eb8c54c1b935b4cd4a12b --- /dev/null +++ b/rgthree-comfy/src_web/link_fixer/index.html @@ -0,0 +1,126 @@ + + + + rgthree's comfy: Workflow Link Fixer + + + + +
+

rgthree's Workflow Link Fixer

+

Early versions of the reroute node would occasionally leave behind stale node-linking data in the graph, which could sometimes cause erratic workflow loading. This tool inspects the workflow's serialized link data and attempts to fix those errors.

+

Drag and drop a ComfyUI-generated image or workflow json into this window to check its serialized links and fix any that are broken.
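Under the hood (an editor's sketch, not part of this page; `workflowGraph` is a hypothetical variable), the page's script runs the parsed workflow through the shared fixBadLinks helper, first to report problems and then, on request, to patch them:

    import { fixBadLinks } from "../common/link_fixer.js"; // assumed relative import

    const report = fixBadLinks(workflowGraph);         // check only
    if (report.patched || report.deleted) {
      const fixed = fixBadLinks(workflowGraph, true);  // apply fixes
      // fixed.graph now holds the repaired workflow, ready to be saved back out as json.
    }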

+ +
+ + + +
+
+ +
+ + + +
+ + + + \ No newline at end of file diff --git a/rgthree-comfy/src_web/link_fixer/link_page.ts b/rgthree-comfy/src_web/link_fixer/link_page.ts new file mode 100644 index 0000000000000000000000000000000000000000..d266ef2ece219d414a78b0f3e58299e5f81d3d73 --- /dev/null +++ b/rgthree-comfy/src_web/link_fixer/link_page.ts @@ -0,0 +1,235 @@ +import type { SerializedGraph, BadLinksData } from "typings/index.js"; +import { fixBadLinks } from "../common/link_fixer.js"; +import { getPngMetadata } from "scripts/pnginfo.js"; + +function wait(ms = 16, value?: any) { + return new Promise((resolve) => { + setTimeout(() => { + resolve(value); + }, ms); + }); +} + +const logger = { + logTo: console as Console | HTMLElement, + log: (...args: any[]) => { + logger.logTo === console + ? console.log(...args) + : ((logger.logTo as HTMLElement).innerText += args.join(",") + "\n"); + }, +}; + +const findBadLinksLogger = { + log: async (...args: any[]) => { + logger.log(...args); + // await wait(48); + }, +}; + +export class LinkPage { + private containerEl: HTMLDivElement; + private figcaptionEl: HTMLElement; + private btnFix: HTMLButtonElement; + private outputeMessageEl: HTMLDivElement; + private outputImageEl: HTMLImageElement; + + private file?: File | Blob; + private graph?: SerializedGraph; + private graphResults?: BadLinksData; + private graphFinalResults?: BadLinksData; + + constructor() { + // const consoleEl = document.getElementById("console")!; + this.containerEl = document.querySelector(".box")!; + this.figcaptionEl = document.querySelector("figcaption")!; + this.outputeMessageEl = document.querySelector(".output")!; + this.outputImageEl = document.querySelector(".output-image")!; + this.btnFix = document.querySelector(".btn-fix")!; + + // Need to prevent on dragover to allow drop... + document.addEventListener( + "dragover", + (e) => { + e.preventDefault(); + }, + false, + ); + document.addEventListener("drop", (e) => { + this.onDrop(e); + }); + this.btnFix.addEventListener("click", (e) => { + this.onFixClick(e); + }); + } + + private async onFixClick(e: MouseEvent) { + if (!this.graphResults || !this.graph) { + this.updateUi("⛔ Fix button click without results."); + return; + } + // Fix + let graphFinalResults = fixBadLinks(this.graph, true); + // Confirm + graphFinalResults = fixBadLinks(graphFinalResults.graph, true); + // This should have happened, but try to run it through again if there's till an issue. + if (graphFinalResults.patched || graphFinalResults.deleted) { + graphFinalResults = fixBadLinks(graphFinalResults.graph, true); + } + this.graphFinalResults = graphFinalResults; + + await this.saveFixedWorkflow(); + + if (graphFinalResults.hasBadLinks) { + this.updateUi( + "⛔ Hmm... Still detecting bad links. Can you file an issue at https://github.com/rgthree/rgthree-comfy/issues with your image/workflow.", + ); + } else { + this.updateUi( + "✅ Workflow fixed.

Please load new saved workflow json and double check linking and execution.", + ); + } + } + + private async onDrop(event: DragEvent) { + if (!event.dataTransfer) { + return; + } + this.reset(); + + event.preventDefault(); + event.stopPropagation(); + + // Dragging from Chrome->Firefox there is a file but its a bmp, so ignore that + if (event.dataTransfer.files.length && event.dataTransfer.files?.[0]?.type !== "image/bmp") { + await this.handleFile(event.dataTransfer.files[0]!); + return; + } + + // Try loading the first URI in the transfer list + const validTypes = ["text/uri-list", "text/x-moz-url"]; + const match = [...event.dataTransfer.types].find((t) => validTypes.find((v) => t === v)); + if (match) { + const uri = event.dataTransfer.getData(match)?.split("\n")?.[0]; + if (uri) { + await this.handleFile(await (await fetch(uri)).blob()); + } + } + } + + reset() { + this.file = undefined; + this.graph = undefined; + this.graphResults = undefined; + this.graphFinalResults = undefined; + this.updateUi(); + } + + private updateUi(msg?: string) { + this.outputeMessageEl.innerHTML = ""; + if (this.file && !this.containerEl.classList.contains("-has-file")) { + this.containerEl.classList.add("-has-file"); + this.figcaptionEl.innerHTML = (this.file as File).name || this.file.type; + if (this.file.type === "application/json") { + this.outputImageEl.src = "icon_file_json.png"; + } else { + const reader = new FileReader(); + reader.onload = () => (this.outputImageEl.src = reader.result as string); + reader.readAsDataURL(this.file); + } + } else if (!this.file && this.containerEl.classList.contains("-has-file")) { + this.containerEl.classList.remove("-has-file"); + this.outputImageEl.src = ""; + this.outputImageEl.removeAttribute("src"); + } + + if (this.graphResults) { + this.containerEl.classList.add("-has-results"); + if (!this.graphResults.patched && !this.graphResults.deleted) { + this.outputeMessageEl.innerHTML = "✅ No bad links detected in the workflow."; + } else { + this.containerEl.classList.add("-has-fixable-results"); + this.outputeMessageEl.innerHTML = `⚠️ Found ${this.graphResults.patched} links to fix, and ${this.graphResults.deleted} to be removed.`; + } + } else { + this.containerEl.classList.remove("-has-results"); + this.containerEl.classList.remove("-has-fixable-results"); + } + + if (msg) { + this.outputeMessageEl.innerHTML = msg; + } + } + + private async handleFile(file: File | Blob) { + this.file = file; + this.updateUi(); + + let workflow: string | undefined | null = null; + if (file.type.startsWith("image/")) { + const pngInfo = await getPngMetadata(file); + workflow = pngInfo?.workflow; + } else if ( + file.type === "application/json" || + (file instanceof File && file.name.endsWith(".json")) + ) { + workflow = await new Promise((resolve) => { + const reader = new FileReader(); + reader.onload = () => { + resolve(reader.result as string); + }; + reader.readAsText(file); + }); + } + if (!workflow) { + this.updateUi("⛔ No workflow found in dropped item."); + } else { + try { + this.graph = JSON.parse(workflow); + } catch (e) { + this.graph = undefined; + } + if (!this.graph) { + this.updateUi("⛔ Invalid workflow found in dropped item."); + } else { + this.loadGraphData(this.graph); + } + } + } + + private async loadGraphData(graphData: SerializedGraph) { + this.graphResults = await fixBadLinks(graphData); + this.updateUi(); + } + + private async saveFixedWorkflow() { + if (!this.graphFinalResults) { + this.updateUi("⛔ Save w/o final graph patched."); + return false; + } + 
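    // (Editor's note.) The rest of this method derives a "<name>_fixed.json" filename from the
    // dropped file, confirms it via prompt(), serializes the repaired graph with JSON.stringify,
    // and triggers a download through a temporary hidden <a> element backed by an object URL,
    // which is revoked once the click has fired.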
+ let filename: string | null = (this.file as File).name || "workflow.json"; + let filenames = filename.split("."); + filenames.pop(); + filename = filenames.join("."); + filename += "_fixed.json"; + filename = prompt("Save workflow as:", filename); + if (!filename) return false; + if (!filename.toLowerCase().endsWith(".json")) { + filename += ".json"; + } + const json = JSON.stringify(this.graphFinalResults.graph, null, 2); + const blob = new Blob([json], { type: "application/json" }); + const url = URL.createObjectURL(blob); + const anchor = document.createElement("a"); + anchor.download = filename; + anchor.href = url; + anchor.style.display = "none"; + document.body.appendChild(anchor); + await wait(); + anchor.click(); + await wait(); + anchor.remove(); + window.URL.revokeObjectURL(url); + return true; + } +} diff --git a/rgthree-comfy/src_web/scripts_comfy/README.md b/rgthree-comfy/src_web/scripts_comfy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..dac5571c7add88362732a1ca74c4849049a2410e --- /dev/null +++ b/rgthree-comfy/src_web/scripts_comfy/README.md @@ -0,0 +1,6 @@ +Here lies dummy ts files that decalre/export ComfyUI's own scripts files as typed types w/o needing +to symlink to the actual implementation. + +Actual code in the comfyui/ directory can import these like `import {app} from "/scripts/app.js"` +and have access to `app` as the fully typed `ComfyApp`. The `__build__.py` script will rewrite these +to the relative browser path. \ No newline at end of file diff --git a/rgthree-comfy/src_web/scripts_comfy/api.ts b/rgthree-comfy/src_web/scripts_comfy/api.ts new file mode 100644 index 0000000000000000000000000000000000000000..61cab401ecf62f39299acc1e8c104ec27ba50561 --- /dev/null +++ b/rgthree-comfy/src_web/scripts_comfy/api.ts @@ -0,0 +1,7 @@ +interface ComfyApi extends EventTarget { + getNodeDefs(): any; + apiURL(url: string): string; + queuePrompt(num: number, data: { output: {}; workflow: {} }): Promise<{}>; +} + +export declare const api: ComfyApi; diff --git a/rgthree-comfy/src_web/scripts_comfy/app.ts b/rgthree-comfy/src_web/scripts_comfy/app.ts new file mode 100644 index 0000000000000000000000000000000000000000..8ef5f5274745d6b61cc652e1c5022c099ffc22c2 --- /dev/null +++ b/rgthree-comfy/src_web/scripts_comfy/app.ts @@ -0,0 +1,7 @@ +import { ComfyApp } from "../typings/comfy.js"; + +/** + * A dummy ComfyApp that we can import from our code, which we'll rewrite later to the comfyui + * hosted app.js + */ +export declare const app: ComfyApp; diff --git a/rgthree-comfy/src_web/scripts_comfy/pnginfo.ts b/rgthree-comfy/src_web/scripts_comfy/pnginfo.ts new file mode 100644 index 0000000000000000000000000000000000000000..7aba24179ccb8ce047140c71bd5ba6c94b8b7279 --- /dev/null +++ b/rgthree-comfy/src_web/scripts_comfy/pnginfo.ts @@ -0,0 +1,7 @@ +export declare const getPngMetadata: (file: File | Blob) => { workflow?: string; prompt?: string }; +export declare const getWebpMetadata: (file: File | Blob) => { + Workflow?: string; + workflow?: string; + Prompt?: string; + prompt?: string; +}; diff --git a/rgthree-comfy/src_web/scripts_comfy/ui/components/button.ts b/rgthree-comfy/src_web/scripts_comfy/ui/components/button.ts new file mode 100644 index 0000000000000000000000000000000000000000..865eebadc89d0fb438a94fe017ccc2258b6c32e9 --- /dev/null +++ b/rgthree-comfy/src_web/scripts_comfy/ui/components/button.ts @@ -0,0 +1,23 @@ +import type { ComfyApp } from "typings/comfy.js"; + +type ComfyButtonProps = { + icon?: string; + overIcon?: string; + 
iconSize?: number; + content?: string | HTMLElement; + tooltip?: string; + enabled?: boolean; + action?: (e: Event, btn: ComfyButton) => void; + classList?: string; + visibilitySetting?: { id: string, showValue: any }; + app?: ComfyApp; +} + +export declare class ComfyButton { + element: HTMLElement; + iconElement: HTMLElement; + contentElement: HTMLElement; + constructor(props: ComfyButtonProps); + updateIcon(): void; + withPopup(popup: any, mode: "click"|"hover"): this; +}; diff --git a/rgthree-comfy/src_web/scripts_comfy/ui/components/buttonGroup.ts b/rgthree-comfy/src_web/scripts_comfy/ui/components/buttonGroup.ts new file mode 100644 index 0000000000000000000000000000000000000000..af8954a8ac59f80ea162094e8a61284bb34c9953 --- /dev/null +++ b/rgthree-comfy/src_web/scripts_comfy/ui/components/buttonGroup.ts @@ -0,0 +1,10 @@ +import type {ComfyButton} from "scripts/ui/components/button.js"; + +export declare class ComfyButtonGroup { + element: HTMLElement; + constructor(...buttons: Array); + insert(button: ComfyButton, index: number): void; + append(button: ComfyButton): void; + remove(indexOrButton: ComfyButton|number): ComfyButton|HTMLElement|void; + update(): void; +}; diff --git a/rgthree-comfy/src_web/scripts_comfy/ui/components/popup.ts b/rgthree-comfy/src_web/scripts_comfy/ui/components/popup.ts new file mode 100644 index 0000000000000000000000000000000000000000..a8eb20284c419ee57e2bd3f0b924a89dbc528cf9 --- /dev/null +++ b/rgthree-comfy/src_web/scripts_comfy/ui/components/popup.ts @@ -0,0 +1,16 @@ +type ComfyPopupProps = { + target: HTMLElement; + container?: HTMLElement; + classList?: string; + ignoreTarget?: boolean, + closeOnEscape?: boolean, + position?: "absolute" | "relative", + horizontal?: "left" | "right" +} + +export declare class ComfyPopup extends EventTarget { + element: HTMLDivElement; + constructor(props: ComfyPopupProps, ...children: HTMLElement[]); + toggle(): void; + update(): void; +}; diff --git a/rgthree-comfy/src_web/scripts_comfy/widgets.ts b/rgthree-comfy/src_web/scripts_comfy/widgets.ts new file mode 100644 index 0000000000000000000000000000000000000000..ee8e9c61561dbf82fccbfdb74f123b80e9bc07d6 --- /dev/null +++ b/rgthree-comfy/src_web/scripts_comfy/widgets.ts @@ -0,0 +1,19 @@ +import type { LGraphNode } from "typings/litegraph.js"; +import type { ComfyApp, ComfyWidget } from "../typings/comfy.js"; + +type ComfyWidgetFn = ( + node: LGraphNode, + inputName: string, + inputData: any, + app: ComfyApp, +) => { widget: ComfyWidget }; + +/** + * A dummy ComfyWidgets that we can import from our code, which we'll rewrite later to the comfyui + * hosted widgets.js + */ +export declare const ComfyWidgets: { + COMBO: ComfyWidgetFn; + STRING: ComfyWidgetFn; + [key: string]: ComfyWidgetFn; +}; diff --git a/rgthree-comfy/src_web/typings/README.md b/rgthree-comfy/src_web/typings/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b8a367e1c870a2009342f6f78b5f89959405935d --- /dev/null +++ b/rgthree-comfy/src_web/typings/README.md @@ -0,0 +1,3 @@ +The typings in node_modules or in ComfyUI's web/ directory were not that well covered. These typings are hacked together with some of the inconsistencies I found. + +To be honest, I have no idea why I needed a bizarre workaround for litegraph's types. Usually the '/// <reference>' comment should have picked up the types, but it wasn't having it. 
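To illustrate the setup described in the scripts_comfy README above (an editor's sketch; the extension name and hook body are hypothetical), a module in the comfyui/ directory can do:

    import { app } from "scripts/app.js"; // typed as ComfyApp via the dummy module
    import { api } from "scripts/api.js"; // typed as ComfyApi

    app.registerExtension({
      name: "rgthree.Example",
      async setup() {
        console.log(await api.getNodeDefs());
      },
    });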
¯\_(ツ)_/¯ \ No newline at end of file diff --git a/rgthree-comfy/src_web/typings/comfy.d.ts b/rgthree-comfy/src_web/typings/comfy.d.ts new file mode 100644 index 0000000000000000000000000000000000000000..bbecb80cf261133aa01b880bc986f55f53f0076c --- /dev/null +++ b/rgthree-comfy/src_web/typings/comfy.d.ts @@ -0,0 +1,227 @@ +import type { LGraphGroup as TLGraphGroup, LGraphNode as TLGraphNode, IWidget, SerializedLGraphNode, LGraph as TLGraph, LGraphCanvas as TLGraphCanvas, LiteGraph as TLiteGraph } from "./litegraph.js"; +import type {Constructor, SerializedGraph} from './index.js'; + +declare global { + const LiteGraph: typeof TLiteGraph; + const LGraph: typeof TLGraph; + const LGraphNode: typeof TLGraphNode; + const LGraphCanvas: typeof TLGraphCanvas; + const LGraphGroup: typeof TLGraphGroup; +} + +// @rgthree: Types on ComfyApp as needed. +export interface ComfyApp { + extensions: ComfyExtension[]; + async queuePrompt(number?: number, batchCount = 1): Promise; + graph: TLGraph; + canvas: TLGraphCanvas; + clean() : void; + registerExtension(extension: ComfyExtension): void; + getPreviewFormatParam(): string; + getRandParam(): string; + loadApiJson(apiData: {}, fileName: string): void; + async graphToPrompt(graph?: TLGraph, clean?: boolean): Promise; + // workflow: ComfyWorkflowInstance ??? + async loadGraphData(graphData: {}, clean?: boolean, restore_view?: boolean, workflow?: any|null): Promise + ui: { + settings: { + addSetting(config: {id: string, name: string, type: () => HTMLElement}) : void; + } + } + // Just marking as any for now. + menu?: any; +} + +export interface ComfyWidget extends IWidget { + // https://github.com/comfyanonymous/ComfyUI/issues/2193 Changes from SerializedLGraphNode to + // LGraphNode... + serializeValue(nodeType: TLGraphNode, index: number): Promise; + afterQueued(): void; + inputEl?: HTMLTextAreaElement; + width: number; +} + +export interface ComfyGraphNode extends TLGraphNode { + getExtraMenuOptions: (node: TLGraphNode, options: ContextMenuItem[]) => void; + onExecuted(message: any): void; +} + +export interface ComfyNode extends TLGraphNode { + comfyClass: string; +} + +// @rgthree +export interface ComfyNodeConstructor extends Constructor { + static title: string; + static type?: string; + static comfyClass: string; +} + +export type NodeMode = 0|1|2|3|4|undefined; + + +export interface ComfyExtension { + /** + * The name of the extension + */ + name: string; + /** + * Allows any initialisation, e.g. loading resources. 
Called after the canvas is created but before nodes are added + * @param app The ComfyUI app instance + */ + init?(app: ComfyApp): Promise; + /** + * Allows any additonal setup, called after the application is fully set up and running + * @param app The ComfyUI app instance + */ + setup?(app: ComfyApp): Promise; + /** + * Called before nodes are registered with the graph + * @param defs The collection of node definitions, add custom ones or edit existing ones + * @param app The ComfyUI app instance + */ + addCustomNodeDefs?(defs: Record, app: ComfyApp): Promise; + /** + * Allows the extension to add custom widgets + * @param app The ComfyUI app instance + * @returns An array of {[widget name]: widget data} + */ + getCustomWidgets?( + app: ComfyApp + ): Promise< + Record { widget?: IWidget; minWidth?: number; minHeight?: number }> + >; + /** + * Allows the extension to add additional handling to the node before it is registered with LGraph + * @rgthree changed nodeType from `typeof LGraphNode` to `ComfyNodeConstructor` + * @param nodeType The node class (not an instance) + * @param nodeData The original node object info config object + * @param app The ComfyUI app instance + */ + beforeRegisterNodeDef?(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo, app: ComfyApp): Promise; + /** + * Allows the extension to register additional nodes with LGraph after standard nodes are added + * @param app The ComfyUI app instance + */ + // @rgthree - add void for non async + registerCustomNodes?(app: ComfyApp): void|Promise; + /** + * Allows the extension to modify a node that has been reloaded onto the graph. + * If you break something in the backend and want to patch workflows in the frontend + * This is the place to do this + * @param node The node that has been loaded + * @param app The ComfyUI app instance + */ + loadedGraphNode?(node: TLGraphNode, app: ComfyApp); + /** + * Allows the extension to run code after the constructor of the node + * @param node The node that has been created + * @param app The ComfyUI app instance + */ + nodeCreated?(node: TLGraphNode, app: ComfyApp); +} + +export type ComfyObjectInfo = { + name: string; + display_name?: string; + description?: string; + category: string; + input?: { + required?: Record; + optional?: Record; + hidden?: Record; + }; + output?: string[]; + output_name: string[]; + // @rgthree + output_node?: boolean; +}; + +export type ComfyObjectInfoConfig = [string | any[]] | [string | any[], any]; + +// @rgthree +type ComfyApiInputLink = [ + /** The id string of the connected node. */ + string, + /** The output index. 
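 *
 * (Editor's illustration, not part of this diff; node ids and class names are only examples.)
 * In the API format declared below, such a tuple wires one node's input to another node's
 * output:
 *
 *   const prompt: ComfyApiFormat = {
 *     "3": {
 *       class_type: "KSampler",
 *       _meta: { title: "KSampler" },
 *       inputs: { seed: 42, model: ["4", 0] },  // ["4", 0] = output 0 of node "4"
 *     },
 *   };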
*/ + number, +] + +// @rgthree +export type ComfyApiFormatNode = { + "inputs": { + [input_name: string]: string|number|boolean|ComfyApiInputLink, + }, + "class_type": string, + "_meta": { + "title": string, + } +} + +// @rgthree +export type ComfyApiFormat = { + [node_id: string]: ComfyApiFormatNode +} + +// @rgthree +export type ComfyApiPrompt = { + workflow: SerializedGraph, + output: ComfyApiFormat, +} + +// @rgthree +export type ComfyApiEventDetailStatus = { + exec_info: { + queue_remaining: number; + }; +}; + +// @rgthree +export type ComfyApiEventDetailExecutionStart = { + prompt_id: string; +}; + +// @rgthree +export type ComfyApiEventDetailExecuting = null | string; + +// @rgthree +export type ComfyApiEventDetailProgress = { + node: string; + prompt_id: string; + max: number; + value: number; +}; + +// @rgthree +export type ComfyApiEventDetailExecuted = { + node: string; + prompt_id: string; + output: any; +}; + +// @rgthree +export type ComfyApiEventDetailCached = { + nodes: string[]; + prompt_id: string; +}; + +// @rgthree +export type ComfyApiEventDetailExecuted = { + prompt_id: string; + node: string; + output: any; +}; + +// @rgthree +export type ComfyApiEventDetailError = { + prompt_id: string; + exception_type: string; + exception_message: string; + node_id: string; + node_type: string; + node_id: string; + traceback: string; + executed: any[]; + current_inputs: {[key: string]: (number[]|string[])}; + current_outputs: {[key: string]: (number[]|string[])}; +} diff --git a/rgthree-comfy/src_web/typings/index.d.ts b/rgthree-comfy/src_web/typings/index.d.ts new file mode 100644 index 0000000000000000000000000000000000000000..a574726e3b9ddadd3019530146cf29eea7d52188 --- /dev/null +++ b/rgthree-comfy/src_web/typings/index.d.ts @@ -0,0 +1,55 @@ +import { LGraph } from "./litegraph.js"; + +export type Constructor = new(...args: any[]) => T; + +export type SerializedLink = [ + number, // this.id, + number, // this.origin_id, + number, // this.origin_slot, + number, // this.target_id, + number, // this.target_slot, + string, // this.type +]; + +export interface SerializedNodeInput { + name: string; + type: string; + link: number; +} +export interface SerializedNodeOutput { + name: string; + type: string; + link: number; + slot_index: number; + links: number[]; +} +export interface SerializedNode { + id: number; + inputs: SerializedNodeInput[]; + outputs: SerializedNodeOutput[]; + mode: number; + order: number; + pos: [number, number]; + properties: any; + size: [number, number]; + type: string; + widgets_values: Array; +} + +export interface SerializedGraph { + config: any; + extra: any; + groups: any; + last_link_id: number; + last_node_id: number; + links: SerializedLink[]; + nodes: SerializedNode[]; +} + +export interface BadLinksData { + hasBadLinks: boolean; + fixed: boolean; + graph: T; + patched: number; + deleted: number; +} diff --git a/rgthree-comfy/src_web/typings/litegraph.d.ts b/rgthree-comfy/src_web/typings/litegraph.d.ts new file mode 100644 index 0000000000000000000000000000000000000000..5bbc23865d03670f17d6f71eb307de2ada545358 --- /dev/null +++ b/rgthree-comfy/src_web/typings/litegraph.d.ts @@ -0,0 +1,1748 @@ +// Type definitions for litegraph.js 0.7.0 +// Project: litegraph.js +// Definitions by: NateScarlet + +export type Vector2 = [number, number]; +export type Vector4 = [number, number, number, number]; +export type widgetTypes = + | "number" + | "slider" + | "combo" + | "text" + | "toggle" + | "button"; +export type SlotShape = + | typeof LiteGraph.BOX_SHAPE + | 
typeof LiteGraph.CIRCLE_SHAPE + | typeof LiteGraph.ARROW_SHAPE + | typeof LiteGraph.SQUARE_SHAPE + | number; // For custom shapes + +/** https://github.com/jagenjo/litegraph.js/tree/master/guides#node-slots */ +export interface INodeSlot { + name: string; + type: string | -1; + label?: string; + dir?: + | typeof LiteGraph.UP + | typeof LiteGraph.RIGHT + | typeof LiteGraph.DOWN + | typeof LiteGraph.LEFT; + color_on?: string; + color_off?: string; + shape?: SlotShape; + locked?: boolean; + nameLocked?: boolean; + pos?: Vector2; + // @rgthree + hidden?: boolean; + // @rgthree + disabled?: boolean; + // @rgthree - Found this checked in getSlotMenuOptions default. + removable?: boolean; + // @rgthree - A status we put on some nodes so we can draw things around it. + rgthree_status?: 'WARN' | 'ERROR'; +} + +export interface INodeInputSlot extends INodeSlot { + link: LLink["id"] | null; + // @rgthree - add comfy widget info + widget?: { + name: string; + } +} + +export interface INodeOutputSlot extends INodeSlot { + links: LLink["id"][] | null; +} + +export type WidgetCallback = ( + this: T, + value: T["value"], + graphCanvas: LGraphCanvas, + node: LGraphNode, + pos: Vector2, + event?: MouseEvent +) => void; + +// @rgthree +export type WidgetComboCallback = ( + this: T, + value: T["value"][0], + graphCanvas: LGraphCanvas, + node: LGraphNode, + pos: Vector2, + event?: MouseEvent +) => void; + +// @rgthree +export type IWidgetOptions = { + y?: number; // ? + property?: string; + serialize?: boolean; // ComfyUI in app.js + forceInput?: boolean; // ComfyUI in app.js + defaultInput?: boolean; // ComfyUI in app.js +} + +// @rgthree +export type IWidgetToggleOptions = IWidgetOptions & { + on?: string; + off?: string; +} + +// @rgthree +export type IWidgetNumberOptions = IWidgetOptions & { + precision?: number; + max?: number; + min?: number; +} + +// @rgthree +export type IWidgetSliderOptions = IWidgetNumberOptions & { + slider_color?: string; + marker_color?: string; +} + + +// @rgthree +export type IWidgetComboOptions = IWidgetOptions & { + values?: string[] | ((widget: IComboWidget, node: LGraphNode) => string[]); +} + +export interface IWidget { + name: string | null; + // @rgthree + label?: string | null; + value: TValue; + options?: TOptions; + // @rgthree - extend to string for custom + type?: widgetTypes | string; + y?: number; + property?: string; + last_y?: number; + clicked?: boolean; + marker?: boolean; + disabled?: boolean; + callback?: WidgetCallback; + /** Called by `LGraphCanvas.drawNodeWidgets` */ + draw?( + ctx: CanvasRenderingContext2D, + node: LGraphNode, + width: number, + posY: number, + height: number + ): void; + /** + * Called by `LGraphCanvas.processNodeWidgets` + * https://github.com/jagenjo/litegraph.js/issues/76 + */ + mouse?( + event: MouseEvent, + pos: Vector2, + node: LGraphNode + ): boolean; + /** Called by `LGraphNode.computeSize` */ + computeSize?(width: number): [number, number]; + // @rgthree - make optional, since it is in the code. + serializeValue?(serializedNode: SerializedLGraphNode, widgetIndex: number): TValue; + // @rgthree - Checked in LGraphCanvas.prototype.processNodeWidgets, and figured I'd use it too. 
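  // (Editor's illustration, not part of this diff; the widget below is hypothetical.) A minimal
  // object satisfying this interface, e.g. a numeric option on a node, could look like:
  //
  //   const speedWidget: INumberWidget = {
  //     name: "speed",
  //     type: "number",
  //     value: 1,
  //     options: { min: 0, max: 10, precision: 2 },
  //     callback: (v) => console.log("speed is now", v),
  //   };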
+ width?: number; +} +export interface IButtonWidget extends IWidget { + type: "button"; +} +// @rgthree: adding options +export interface IToggleWidget extends IWidget { + type: "toggle"; +} +// @rgthree: adding options +export interface ISliderWidget extends IWidget { + type: "slider"; +} +// @rgthree: adding options +export interface INumberWidget extends IWidget { + type: "number"; +} +// @rgthree: adding options +export interface IComboWidget extends IWidget { + value: T[0]; + type: "combo"; + callback?: WidgetComboCallback; +} + +export interface ITextWidget extends IWidget { + type: "text"; +} + +export interface IContextMenuItem { + // @rgthree - Make optional because, I guess it is? + content?: string; + value?: any; + callback?: ContextMenuEventListener; + /** Used as innerHTML for extra child element */ + title?: string; + disabled?: boolean; + has_submenu?: boolean; + submenu?: { + options: ContextMenuItem[]; + } & IContextMenuOptions; + className?: string; + // @rgthree - Added for menu_auto_nest + rgthree_originalValue?: IContextMenuItem; + // @rgthree - this was missing and passed through for getSlotMenuOptions default. + slot?: {input?: INodeInputSlot, output?: INodeOutputSlot}; +} +export interface IContextMenuOptions { + callback?: ContextMenuEventListener; + ignore_item_callbacks?: Boolean; + event?: MouseEvent | CustomEvent | AdjustedMouseEvent; + parentMenu?: ContextMenu|null; + autoopen?: boolean; + title?: string; + extra?: any; + // @rgthree + scale?: number; + // @rgthree + left?: number; + // @rgthree + top?: number; + // @rgthree + className?: string; + // @rgthree - Added for menu_auto_nest + rgthree_originalCallback?: ContextMenuEventListener; + // @rgthree - No idea since it's not documented, but we'll use it to pass data, like rgthree_doNotNest + extra?: any +} + +export type ContextMenuItem = IContextMenuItem | null; +export type ContextMenuEventListener = ( + value: ContextMenuItem, + options: IContextMenuOptions, + event: MouseEvent, + parentMenu: ContextMenu | undefined, + node: LGraphNode +) => boolean | void; + +export const LiteGraph: { + VERSION: number; + + CANVAS_GRID_SIZE: number; + + NODE_TITLE_HEIGHT: number; + NODE_TITLE_TEXT_Y: number; + NODE_SLOT_HEIGHT: number; + NODE_WIDGET_HEIGHT: number; + NODE_WIDTH: number; + NODE_MIN_WIDTH: number; + NODE_COLLAPSED_RADIUS: number; + NODE_COLLAPSED_WIDTH: number; + NODE_TITLE_COLOR: string; + NODE_TEXT_SIZE: number; + NODE_TEXT_COLOR: string; + NODE_SUBTEXT_SIZE: number; + NODE_DEFAULT_COLOR: string; + NODE_DEFAULT_BGCOLOR: string; + NODE_DEFAULT_BOXCOLOR: string; + NODE_DEFAULT_SHAPE: string; + // @rgthree - didn't exist. + NODE_BOX_OUTLINE_COLOR: string; + DEFAULT_SHADOW_COLOR: string; + DEFAULT_GROUP_FONT: number; + // @rgthree - Seems to have been missing. 
+ NODE_BOX_OUTLINE_COLOR: string; + + WIDGET_BGCOLOR: string; + WIDGET_OUTLINE_COLOR: string; + WIDGET_TEXT_COLOR: string; + WIDGET_SECONDARY_TEXT_COLOR: string; + + LINK_COLOR: string; + EVENT_LINK_COLOR: string; + CONNECTING_LINK_COLOR: string; + + MAX_NUMBER_OF_NODES: number; //avoid infinite loops + DEFAULT_POSITION: Vector2; //default node position + VALID_SHAPES: ["default", "box", "round", "card"]; //,"circle" + + //shapes are used for nodes but also for slots + BOX_SHAPE: 1; + ROUND_SHAPE: 2; + CIRCLE_SHAPE: 3; + CARD_SHAPE: 4; + ARROW_SHAPE: 5; + GRID_SHAPE: 6; + + //enums + INPUT: 1; + OUTPUT: 2; + + EVENT: -1; //for outputs + ACTION: -1; //for inputs + + ALWAYS: 0; + ON_EVENT: 1; + NEVER: 2; + ON_TRIGGER: 3; + + UP: 1; + DOWN: 2; + LEFT: 3; + RIGHT: 4; + CENTER: 5; + + STRAIGHT_LINK: 0; + LINEAR_LINK: 1; + SPLINE_LINK: 2; + + NORMAL_TITLE: 0; + NO_TITLE: 1; + TRANSPARENT_TITLE: 2; + AUTOHIDE_TITLE: 3; + + node_images_path: string; + + // @rgthree. These just weren't there. Note, LiteGraph initializes these as an array, but + // ComfyUI overrides these to a string-keye'd object... ??? + slot_types_default_out: {[key: string]: string[]}; + slot_types_default_in: {[key: string]: string[]}; + + debug: boolean; + catch_exceptions: boolean; + throw_errors: boolean; + /** if set to true some nodes like Formula would be allowed to evaluate code that comes from unsafe sources (like node configuration), which could lead to exploits */ + allow_scripts: boolean; + /** node types by string */ + registered_node_types: Record; + /** used for dropping files in the canvas */ + node_types_by_file_extension: Record; + /** node types by class name */ + Nodes: Record; + + /** used to add extra features to the search box */ + searchbox_extras: Record< + string, + { + data: { outputs: string[][]; title: string }; + desc: string; + type: string; + } + >; + + //@rgthree + isValidConnection(type: string|string[], type: string|string[]):boolean; + overlapBounding(a: Vector4, b: Vector4) : boolean; + + createNode(type: string): T; + /** Register a node class so it can be listed when the user wants to create a new one */ + registerNodeType(type: string, base: { new (title?: string): T }): void; + /** removes a node type from the system */ + unregisterNodeType(type: string): void; + /** Removes all previously registered node's types. */ + clearRegisteredTypes(): void; + /** + * Create a new node type by passing a function, it wraps it with a proper class and generates inputs according to the parameters of the function. + * Useful to wrap simple methods that do not require properties, and that only process some input to generate an output. + * @param name node name with namespace (p.e.: 'math/sum') + * @param func + * @param param_types an array containing the type of every parameter, otherwise parameters will accept any type + * @param return_type string with the return type, otherwise it will be generic + * @param properties properties to be configurable + */ + wrapFunctionAsNode( + name: string, + func: (...args: any[]) => any, + param_types?: string[], + return_type?: string, + properties?: object + ): void; + + /** + * Adds this method to all node types, existing and to be created + * (You can add it to LGraphNode.prototype but then existing node types wont have it) + */ + addNodeMethod(name: string, func: (...args: any[]) => any): void; + + /** + * Create a node of a given type with a name. The node is not attached to any graph yet. + * @param type full name of the node class. p.e. 
"math/sin" + * @param name a name to distinguish from other nodes + * @param options to set options + */ + createNode( + type: string, + title: string, + options: object + ): T; + + /** + * Returns a registered node type with a given name + * @param type full name of the node class. p.e. "math/sin" + */ + getNodeType(type: string): LGraphNodeConstructor; + + /** + * Returns a list of node types matching one category + * @method getNodeTypesInCategory + * @param {String} category category name + * @param {String} filter only nodes with ctor.filter equal can be shown + * @return {Array} array with all the node classes + */ + getNodeTypesInCategory( + category: string, + filter: string + ): LGraphNodeConstructor[]; + + /** + * Returns a list with all the node type categories + * @method getNodeTypesCategories + * @param {String} filter only nodes with ctor.filter equal can be shown + * @return {Array} array with all the names of the categories + */ + getNodeTypesCategories(filter: string): string[]; + + /** debug purposes: reloads all the js scripts that matches a wildcard */ + reloadNodes(folder_wildcard: string): void; + + getTime(): number; + LLink: typeof LLink; + LGraph: typeof LGraph; + DragAndScale: typeof DragAndScale; + compareObjects(a: object, b: object): boolean; + distance(a: Vector2, b: Vector2): number; + colorToString(c: string): string; + isInsideRectangle( + x: number, + y: number, + left: number, + top: number, + width: number, + height: number + ): boolean; + growBounding(bounding: Vector4, x: number, y: number): Vector4; + isInsideBounding(p: Vector2, bb: Vector4): boolean; + hex2num(hex: string): [number, number, number]; + num2hex(triplet: [number, number, number]): string; + ContextMenu: typeof ContextMenu; + extendClass(target: A, origin: B): A & B; + getParameterNames(func: string): string[]; + // @rgthree + closeAllContextMenus(ref_window?: Window): void; +}; + +export type serializedLGraph< + TNode = ReturnType, + // https://github.com/jagenjo/litegraph.js/issues/74 + TLink = [number, number, number, number, number, string], + TGroup = ReturnType +> = { + last_node_id: LGraph["last_node_id"]; + last_link_id: LGraph["last_link_id"]; + nodes: TNode[]; + links: TLink[]; + groups: TGroup[]; + config: LGraph["config"]; + version: typeof LiteGraph.VERSION; +}; + +export declare class LGraph { + static supported_types: string[]; + static STATUS_STOPPED: 1; + static STATUS_RUNNING: 2; + + constructor(o?: object); + + filter: string; + catch_errors: boolean; + /** custom data */ + config: object; + elapsed_time: number; + fixedtime: number; + fixedtime_lapse: number; + globaltime: number; + inputs: any; + iteration: number; + last_link_id: number; + last_node_id: number; + last_update_time: number; + links: Record; + list_of_graphcanvas: LGraphCanvas[]; + outputs: any; + runningtime: number; + starttime: number; + status: typeof LGraph.STATUS_RUNNING | typeof LGraph.STATUS_STOPPED; + + // @rgthree, remove private; it's not really private b/c it's javascript. + _nodes: LGraphNode[]; + // @rgthree, remove private; it's not really private b/c it's javascript. 
+ _groups: LGraphGroup[]; + private _nodes_by_id: Record; + /** nodes that are executable sorted in execution order */ + private _nodes_executable: + | (LGraphNode & { onExecute: NonNullable }[]) + | null; + /** nodes that contain onExecute */ + private _nodes_in_order: LGraphNode[]; + private _version: number; + + getSupportedTypes(): string[]; + /** Removes all nodes from this graph */ + clear(): void; + /** Attach Canvas to this graph */ + attachCanvas(graphCanvas: LGraphCanvas): void; + /** Detach Canvas to this graph */ + detachCanvas(graphCanvas: LGraphCanvas): void; + /** + * Starts running this graph every interval milliseconds. + * @param interval amount of milliseconds between executions, if 0 then it renders to the monitor refresh rate + */ + start(interval?: number): void; + /** Stops the execution loop of the graph */ + stop(): void; + /** + * Run N steps (cycles) of the graph + * @param num number of steps to run, default is 1 + */ + runStep(num?: number, do_not_catch_errors?: boolean): void; + /** + * Updates the graph execution order according to relevance of the nodes (nodes with only outputs have more relevance than + * nodes with only inputs. + */ + updateExecutionOrder(): void; + /** This is more internal, it computes the executable nodes in order and returns it */ + computeExecutionOrder(only_onExecute: boolean, set_level: any): T; + /** + * Returns all the nodes that could affect this one (ancestors) by crawling all the inputs recursively. + * It doesn't include the node itself + * @return an array with all the LGraphNodes that affect this node, in order of execution + */ + getAncestors(node: LGraphNode): LGraphNode[]; + /** + * Positions every node in a more readable manner + */ + arrange(margin?: number,layout?: string): void; + /** + * Returns the amount of time the graph has been running in milliseconds + * @return number of milliseconds the graph has been running + */ + getTime(): number; + + /** + * Returns the amount of time accumulated using the fixedtime_lapse var. This is used in context where the time increments should be constant + * @return number of milliseconds the graph has been running + */ + getFixedTime(): number; + + /** + * Returns the amount of time it took to compute the latest iteration. Take into account that this number could be not correct + * if the nodes are using graphical actions + * @return number of milliseconds it took the last cycle + */ + getElapsedTime(): number; + /** + * Sends an event to all the nodes, useful to trigger stuff + * @param eventName the name of the event (function to be called) + * @param params parameters in array format + */ + sendEventToAllNodes(eventName: string, params: any[], mode?: any): void; + + sendActionToCanvas(action: any, params: any[]): void; + /** + * Adds a new node instance to this graph + * @param node the instance of the node + */ + add(node: LGraphNode, skip_compute_order?: boolean): void; + /** + * Called when a new node is added + * @param node the instance of the node + */ + onNodeAdded(node: LGraphNode): void; + /** Removes a node from the graph */ + remove(node: LGraphNode): void; + /** + * Returns a node by its id. + * @rgthree - make id options/nullable. 
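 *
 * (Editor's illustration, not part of this diff; `someLinkId` is a hypothetical variable.) The
 * nullable id makes call sites like resolving a link's endpoints straightforward:
 *
 *   const link = graph.links[someLinkId];
 *   const originNode = graph.getNodeById(link?.origin_id);  // undefined if either is missing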
+ */ + getNodeById(id?: number|null): LGraphNode | undefined; + /** + * Returns a list of nodes that matches a class + * @param classObject the class itself (not an string) + * @return a list with all the nodes of this type + */ + findNodesByClass( + classObject: LGraphNodeConstructor + ): T[]; + /** + * Returns a list of nodes that matches a type + * @param type the name of the node type + * @return a list with all the nodes of this type + */ + findNodesByType(type: string): T[]; + /** + * Returns the first node that matches a name in its title + * @param title the name of the node to search + * @return the node or null + */ + findNodeByTitle(title: string): T | null; + /** + * Returns a list of nodes that matches a name + * @param title the name of the node to search + * @return a list with all the nodes with this name + */ + findNodesByTitle(title: string): T[]; + /** + * Returns the top-most node in this position of the canvas + * @param x the x coordinate in canvas space + * @param y the y coordinate in canvas space + * @param nodes_list a list with all the nodes to search from, by default is all the nodes in the graph + * @return the node at this position or null + */ + getNodeOnPos( + x: number, + y: number, + node_list?: LGraphNode[], + margin?: number + ): T | null; + /** + * Returns the top-most group in that position + * @param x the x coordinate in canvas space + * @param y the y coordinate in canvas space + * @return the group or null + */ + getGroupOnPos(x: number, y: number): LGraphGroup | null; + + onAction(action: any, param: any): void; + trigger(action: any, param: any): void; + /** Tell this graph it has a global graph input of this type */ + addInput(name: string, type: string, value?: any): void; + /** Assign a data to the global graph input */ + setInputData(name: string, data: any): void; + /** Returns the current value of a global graph input */ + getInputData(name: string): T; + /** Changes the name of a global graph input */ + renameInput(old_name: string, name: string): false | undefined; + /** Changes the type of a global graph input */ + changeInputType(name: string, type: string): false | undefined; + /** Removes a global graph input */ + removeInput(name: string): boolean; + /** Creates a global graph output */ + addOutput(name: string, type: string, value: any): void; + /** Assign a data to the global output */ + setOutputData(name: string, value: string): void; + /** Returns the current value of a global graph output */ + getOutputData(name: string): T; + + /** Renames a global graph output */ + renameOutput(old_name: string, name: string): false | undefined; + /** Changes the type of a global graph output */ + changeOutputType(name: string, type: string): false | undefined; + /** Removes a global graph output */ + removeOutput(name: string): boolean; + triggerInput(name: string, value: any): void; + setCallback(name: string, func: (...args: any[]) => any): void; + beforeChange(info?: LGraphNode): void; + afterChange(info?: LGraphNode): void; + connectionChange(node: LGraphNode): void; + /** returns if the graph is in live mode */ + isLive(): boolean; + /** clears the triggered slot animation in all links (stop visual animation) */ + clearTriggeredSlots(): void; + /* Called when something visually changed (not the graph!) 
*/ + change(): void; + setDirtyCanvas(fg: boolean, bg: boolean): void; + /** Destroys a link */ + removeLink(link_id: number): void; + /** Creates a Object containing all the info about this graph, it can be serialized */ + serialize(): T; + /** + * Configure a graph from a JSON string + * @param data configure a graph from a JSON string + * @returns if there was any error parsing + */ + configure(data: object, keep_old?: boolean): boolean | undefined; + load(url: string): void; +} + +export type SerializedLLink = [number, string, number, number, number, number]; +export declare class LLink { + id: number; + type: string; + origin_id: number; + origin_slot: number; + target_id: number; + target_slot: number; + constructor( + id: number, + type: string, + origin_id: number, + origin_slot: number, + target_id: number, + target_slot: number + ); + configure(o: LLink | SerializedLLink): void; + serialize(): SerializedLLink; + // @rgthree + color?: string; +} + +export type SerializedLGraphNode = { + id: T["id"]; + type: T["type"]; + pos: T["pos"]; + size: T["size"]; + flags: T["flags"]; + mode: T["mode"]; + inputs: T["inputs"]; + outputs: T["outputs"]; + title: T["title"]; + properties: T["properties"]; + widgets_values?: IWidget["value"][]; +}; + +/** https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#lgraphnode */ +export declare class LGraphNode { + + // @rgthree added + findInputSlotByType(type: string, returnObj?: boolean, preferFreeSlot?: boolean, doNotUseOccupied?: boolean): number + findOutputSlotByType(type: string, returnObj?: boolean, preferFreeSlot?: boolean, doNotUseOccupied?: boolean): number + onShowCustomPanelInfo(panel: HTMLElement): void; + onDblClick?(event: AdjustedMouseEvent, pos: Vector2, canvas: LGraphCanvas): void; + inResizeCorner(x: number, y:number) : boolean; + onWidgetChanged?(widgetName: string, widgetValue: any, oldWidgetValue: any, widget: IWidget): void; + + // end @rgthree added + + static title_color?: string; + static title: string; + static type: null | string; + static widgets_up: boolean; + constructor(title?: string); + + title: string; + // @rgthree - made undefined since ComfyNode does it (even through LiteGraph does not..) + type?: null | string; + size: Vector2; + graph: null | LGraph; + graph_version: number; + pos: Vector2; + is_selected: boolean; + mouseOver: boolean; + // @rgthree - missing. 
+ block_delete: boolean; + + id: number; + + widgets: IWidget[]; + //inputs available: array of inputs + inputs: INodeInputSlot[]; + outputs: INodeOutputSlot[]; + connections: any[]; + + // @rgthree + _collapsed_width?: number; + + //local data + properties: Record; + properties_info: any[]; + + flags: Partial<{ + collapsed: boolean + // @rgthree + allow_interaction: boolean; + // @rgthree + pinned: boolean; + }>; + + color: string; + bgcolor: string; + boxcolor: string; + shape: + | typeof LiteGraph.BOX_SHAPE + | typeof LiteGraph.ROUND_SHAPE + | typeof LiteGraph.CIRCLE_SHAPE + | typeof LiteGraph.CARD_SHAPE + | typeof LiteGraph.ARROW_SHAPE; + + serialize_widgets: boolean; + skip_list: boolean; + + /** Used in `LGraphCanvas.onMenuNodeMode` */ + mode?: + | typeof LiteGraph.ON_EVENT + | typeof LiteGraph.ON_TRIGGER + | typeof LiteGraph.NEVER + | typeof LiteGraph.ALWAYS + | 4; // Comfy App "Bypass" + + /** If set to true widgets do not start after the slots */ + widgets_up: boolean; + /** widgets start at y distance from the top of the node */ + widgets_start_y: number; + /** if you render outside the node, it will be clipped */ + clip_area: boolean; + /** if set to false it wont be resizable with the mouse */ + resizable: boolean; + /** slots are distributed horizontally */ + horizontal: boolean; + /** if true, the node will show the bgcolor as 'red' */ + has_errors?: boolean; + + // @rgthree + setSize(size: Vector2): void; + onResize?(size: Vector2): void; + onInputClick(slot: number, event: MouseEvent): void; + onOutputClick(slot: number, event: MouseEvent): void; + getConnectionPos(isInput: boolean, slotNumber: number, out: Vector2): Vector2; + + /** configure a node from an object containing the serialized info */ + configure(info: SerializedLGraphNode): void; + /** serialize the content */ + serialize(): SerializedLGraphNode; + /** Creates a clone of this node */ + clone(): this; + /** serialize and stringify */ + toString(): string; + /** get the title string */ + getTitle(): string; + /** sets the value of a property */ + setProperty(name: string, value: any): void; + /** sets the output data */ + setOutputData(slot: number, data: any): void; + /** sets the output data */ + setOutputDataType(slot: number, type: string): void; + /** + * Retrieves the input data (data traveling through the connection) from one slot + * @param slot + * @param force_update if set to true it will force the connected node of this slot to output data into this link + * @return data or if it is not connected returns undefined + */ + getInputData(slot: number, force_update?: boolean): T; + /** + * Retrieves the input data type (in case this supports multiple input types) + * @param slot + * @return datatype in string format + */ + getInputDataType(slot: number): string; + /** + * Retrieves the input data from one slot using its name instead of slot number + * @param slot_name + * @param force_update if set to true it will force the connected node of this slot to output data into this link + * @return data or if it is not connected returns null + */ + getInputDataByName(slot_name: string, force_update?: boolean): T; + /** tells you if there is a connection in one input slot */ + isInputConnected(slot: number): boolean; + /** tells you info about an input connection (which node, type, etc) */ + getInputInfo( + slot: number + ): { link: number; name: string; type: string | 0 } | null; + /** returns the node connected in the input slot */ + getInputNode(slot: number): LGraphNode | null; + /** returns the value of an 
input with this name, otherwise checks if there is a property with that name */ + getInputOrProperty(name: string): T; + /** tells you the last output data that went in that slot */ + getOutputData(slot: number): T | null; + /** tells you info about an output connection (which node, type, etc) */ + getOutputInfo( + slot: number + ): { name: string; type: string; links: number[] } | null; + /** tells you if there is a connection in one output slot */ + isOutputConnected(slot: number): boolean; + /** tells you if there is any connection in the output slots */ + isAnyOutputConnected(): boolean; + /** retrieves all the nodes connected to this output slot */ + getOutputNodes(slot: number): LGraphNode[]; + /** Triggers an event in this node, this will trigger any output with the same name */ + trigger(action: string, param: any): void; + /** + * Triggers an slot event in this node + * @param slot the index of the output slot + * @param param + * @param link_id in case you want to trigger and specific output link in a slot + */ + triggerSlot(slot: number, param: any, link_id?: number): void; + /** + * clears the trigger slot animation + * @param slot the index of the output slot + * @param link_id in case you want to trigger and specific output link in a slot + */ + clearTriggeredSlot(slot: number, link_id?: number): void; + /** + * add a new property to this node + * @param name + * @param default_value + * @param type string defining the output type ("vec3","number",...) + * @param extra_info this can be used to have special properties of the property (like values, etc) + */ + addProperty( + name: string, + default_value: any, + type: string, + extra_info?: object + ): T; + /** + * add a new output slot to use in this node + * @param name + * @param type string defining the output type ("vec3","number",...) + * @param extra_info this can be used to have special properties of an output (label, special color, position, etc) + */ + addOutput( + name: string, + type: string | -1, + extra_info?: Partial + ): INodeOutputSlot; + /** + * add a new output slot to use in this node + * @param array of triplets like [[name,type,extra_info],[...]] + */ + addOutputs( + array: [string, string | -1, Partial | undefined][] + ): void; + /** remove an existing output slot */ + removeOutput(slot: number): void; + /** + * add a new input slot to use in this node + * @param name + * @param type string defining the input type ("vec3","number",...), it its a generic one use 0 + * @param extra_info this can be used to have special properties of an input (label, color, position, etc) + */ + addInput( + name: string, + type: string | -1, + extra_info?: Partial + ): INodeInputSlot; + /** + * add several new input slots in this node + * @param array of triplets like [[name,type,extra_info],[...]] + */ + addInputs( + array: [string, string | -1, Partial | undefined][] + ): void; + /** remove an existing input slot */ + removeInput(slot: number): void; + /** + * add an special connection to this node (used for special kinds of graphs) + * @param name + * @param type string defining the input type ("vec3","number",...) 
+ * @param pos position of the connection inside the node + * @param direction if is input or output + */ + addConnection( + name: string, + type: string, + pos: Vector2, + direction: string + ): { + name: string; + type: string; + pos: Vector2; + direction: string; + links: null; + }; + setValue(v: any): void; + /** computes the size of a node according to its inputs and output slots */ + computeSize(out?: Vector2): [number, number]; + /** + * https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#node-widgets + * @return created widget + */ + addWidget( + type: T["type"], + name: string, + value: T["value"], + // @rgthree + callback?: T["callback"] | string,//WidgetCallback | string, + options?: T["options"] + ): T; + + addCustomWidget(customWidget: T): T; + + /** + * returns the bounding of the object, used for rendering purposes + * @return [x, y, width, height] + */ + getBounding(): Vector4; + /** checks if a point is inside the shape of a node */ + isPointInside( + x: number, + y: number, + margin?: number, + skipTitle?: boolean + ): boolean; + /** checks if a point is inside a node slot, and returns info about which slot */ + getSlotInPosition( + x: number, + y: number + ): { + input?: INodeInputSlot; + output?: INodeOutputSlot; + slot: number; + link_pos: Vector2; + }; + /** + * returns the input slot with a given name (used for dynamic slots), -1 if not found + * @param name the name of the slot + * @return the slot (-1 if not found) + */ + findInputSlot(name: string): number; + /** + * returns the output slot with a given name (used for dynamic slots), -1 if not found + * @param name the name of the slot + * @return the slot (-1 if not found) + */ + findOutputSlot(name: string): number; + /** + * connect this node output to the input of another node + * @param slot (could be the number of the slot or the string with the name of the slot) + * @param targetNode the target node + * @param targetSlot the input slot of the target node (could be the number of the slot or the string with the name of the slot, or -1 to connect a trigger) + * @return {Object} the link_info is created, otherwise null + */ + connect( + slot: number | string, + targetNode: LGraphNode, + targetSlot: number | string + ): T | null; + + connectByTypeOutput( + slot: number | string, + sourceNode: LGraphNode, + sourceSlotType: string, + optsIn: string + ): T | null; + + connectByType( + slot: number | string, + sourceNode: LGraphNode, + sourceSlotType: string, + optsIn: string + ): T | null; + + + /** + * disconnect one output to an specific node + * @param slot (could be the number of the slot or the string with the name of the slot) + * @param target_node the target node to which this slot is connected [Optional, if not target_node is specified all nodes will be disconnected] + * @return if it was disconnected successfully + */ + disconnectOutput(slot: number | string, targetNode?: LGraphNode): boolean; + /** + * disconnect one input + * @param slot (could be the number of the slot or the string with the name of the slot) + * @return if it was disconnected successfully + */ + disconnectInput(slot: number | string): boolean; + /** + * returns the center of a connection point in canvas coords + * @param is_input true if if a input slot, false if it is an output + * @param slot (could be the number of the slot or the string with the name of the slot) + * @param out a place to store the output, to free garbage + * @return the position + **/ + getConnectionPos( + is_input: boolean, + slot: number | 
string, + out?: Vector2 + ): Vector2; + /** Force align to grid */ + alignToGrid(): void; + /** Console output */ + trace(msg: string): void; + /** Forces to redraw or the main canvas (LGraphNode) or the bg canvas (links) */ + setDirtyCanvas(fg: boolean, bg: boolean): void; + loadImage(url: string): void; + /** Allows to get onMouseMove and onMouseUp events even if the mouse is out of focus */ + captureInput(v: any): void; + /** Collapse the node to make it smaller on the canvas */ + collapse(force: boolean): void; + /** Forces the node to do not move or realign on Z */ + pin(v?: boolean): void; + localToScreen(x: number, y: number, graphCanvas: LGraphCanvas): Vector2; + + // https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#custom-node-appearance + onDrawBackground?( + ctx: CanvasRenderingContext2D, + // @rgthree fixed + canvas: LGraphCanvas + ): void; + onDrawForeground?( + ctx: CanvasRenderingContext2D, + // @rgthree fixed + canvas: LGraphCanvas + ): void; + + // https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#custom-node-behaviour + onMouseDown?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseMove?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseUp?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseEnter?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseLeave?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onKey?(event: KeyboardEvent, pos: Vector2, graphCanvas: LGraphCanvas): void; + // @rgthree + onKeyDown?(event: KeyboardEvent): void; + // @rgthree + onKeyUp?(event: KeyboardEvent): void; + + onExecuted(message: any): void; + /** Called by `LGraphNode.createNode` */ + onNodeCreated?(): void; + /** Called by `LGraphCanvas.selectNodes` */ + onSelected?(): void; + /** Called by `LGraphCanvas.deselectNode` */ + onDeselected?(): void; + /** Called by `LGraph.runStep` `LGraphNode.getInputData` */ + onExecute?(): void; + /** Called by `LGraph.serialize` */ + onSerialize?(o: SerializedLGraphNode): void; + /** Called by `LGraph.configure` */ + onConfigure?(o: SerializedLGraphNode): void; + /** + * when added to graph (warning: this is called BEFORE the node is configured when loading) + * Called by `LGraph.add` + */ + onAdded?(graph: LGraph): void; + /** + * when removed from graph + * Called by `LGraph.remove` `LGraph.clear` + */ + onRemoved?(): void; + /** + * if returns false the incoming connection will be canceled + * Called by `LGraph.connect` + * @param inputIndex target input slot number + * @param outputType type of output slot + * @param outputSlot output slot object + * @param outputNode node containing the output + * @param outputIndex index of output slot + */ + onConnectInput?( + inputIndex: number, + outputType: INodeOutputSlot["type"], + outputSlot: INodeOutputSlot, + outputNode: LGraphNode, + outputIndex: number + ): boolean; + /** + * if returns false the incoming connection will be canceled + * Called by `LGraph.connect` + * @param outputIndex target output slot number + * @param inputType type of input slot + * @param inputSlot input slot object + * @param inputNode node containing the input + * @param inputIndex index of input slot + */ + onConnectOutput?( + outputIndex: number, + inputType: INodeInputSlot["type"], + inputSlot: INodeInputSlot, + inputNode: LGraphNode, + inputIndex: number + ): boolean; + + /** + * Called just before connection (or 
disconnect - if input is linked). + * A convenient place to switch to another input, or create new one. + * This allow for ability to automatically add slots if needed + * @param inputIndex + * @return selected input slot index, can differ from parameter value + */ + onBeforeConnectInput?( + inputIndex: number + ): number; + + /** a connection changed (new one or removed) (LiteGraph.INPUT or LiteGraph.OUTPUT, slot, true if connected, link_info, input_info or output_info ) */ + onConnectionsChange( + type: number, + slotIndex: number, + isConnected: boolean, + link: LLink, + // @rgthree - Make it INodeSlot instead of union + ioSlot: INodeSlot + ): void; + + /** + * if returns false, will abort the `LGraphNode.setProperty` + * Called when a property is changed + * @param property + * @param value + * @param prevValue + */ + onPropertyChanged?(property: string, value: any, prevValue: any): void | boolean; + + /** Called by `LGraphCanvas.processContextMenu` */ + getMenuOptions?(graphCanvas: LGraphCanvas): ContextMenuItem[]; + // @rgthree. This is fixed because the INodeSlot is wrong and, also, null can be returned to not trigger a menu. + // getSlotMenuOptions?(slot: INodeSlot): ContextMenuItem[]; + getSlotMenuOptions(slot: {input?: INodeInputSlot, output?: INodeOutputSlot}): ContextMenuItem[] | null; + + getExtraMenuOptions?(canvas: LGraphCanvas, options: ContextMenuItem[]): void; + + // @rgthree - Called in LiteGraph.core when the properties panel is constructed. + onShowCustomPanelInfo(panel: HTMLElement): void; +} + +export type LGraphNodeConstructor = { + new (): T; + + // @rgthree + title_mode?: + typeof LiteGraph.NORMAL_TITLE | + typeof LiteGraph.TRANSPARENT_TITLE | + typeof LiteGraph.AUTOHIDE_TITLE | + typeof LiteGraph.NO_TITLE; + title: string; + category: string; + type: string; + comfyClass?: string; +}; + +export type SerializedLGraphGroup = { + title: LGraphGroup["title"]; + bounding: LGraphGroup["_bounding"]; + color: LGraphGroup["color"]; + font: LGraphGroup["font"]; +}; +export declare class LGraphGroup { + title: string; + // @rgthree - mark unprivate + _bounding: Vector4; + // @rgthree - updated to make optional because it seems to be so. + color?: string|null; + font: string; + // @rgthree + _nodes: LGraphNode[]; + // @rgthree + _pos: Vector2; + // @rgthree + _size: Vector2; + // @rgthree + graph: LGraph; + // @rgthree - apparently it is available? + size: Vector2; + // @rgthree - apparently it is available? + pos: Vector2; + + + configure(o: SerializedLGraphGroup): void; + serialize(): SerializedLGraphGroup; + move(deltaX: number, deltaY: number, ignoreNodes?: boolean): void; + recomputeInsideNodes(): void; + isPointInside: LGraphNode["isPointInside"]; + setDirtyCanvas: LGraphNode["setDirtyCanvas"]; +} + +export declare class DragAndScale { + constructor(element?: HTMLElement, skipEvents?: boolean); + offset: [number, number]; + scale: number; + max_scale: number; + min_scale: number; + onredraw: Function | null; + enabled: boolean; + last_mouse: Vector2; + element: HTMLElement | null; + visible_area: Vector4; + bindEvents(element: HTMLElement): void; + computeVisibleArea(): void; + onMouse(e: MouseEvent): void; + toCanvasContext(ctx: CanvasRenderingContext2D): void; + convertOffsetToCanvas(pos: Vector2): Vector2; + convertCanvasToOffset(pos: Vector2): Vector2; + mouseDrag(x: number, y: number): void; + changeScale(value: number, zooming_center?: Vector2): void; + changeDeltaScale(value: number, zooming_center?: Vector2): void; + reset(): void; +} + +// @rgthree. 
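The declarations up to this point cover the core litegraph.js surface (LGraph, LGraphNode, LLink, LGraphGroup, DragAndScale). As a quick orientation, here is a minimal usage sketch of that API, assuming the stock LiteGraph demo node types ("basic/const", "basic/watch") are registered and that LGraph, LGraphCanvas and LiteGraph are importable from the litegraph.js module (in ComfyUI's web context they are typically available as globals); the canvas selector "#graph-canvas" is likewise an assumed placeholder:

// Minimal sketch: build a two-node graph and start the run loop.
// Assumes the demo node types above are registered; adjust the import
// to however litegraph.js is loaded in your setup.
import { LGraph, LGraphCanvas, LiteGraph } from "litegraph.js";

const graph = new LGraph();

// Create a constant node and set its value (see LGraphNode.setValue above).
const constNode = LiteGraph.createNode("basic/const");
constNode.pos = [100, 200];
graph.add(constNode);
constNode.setValue(4.5);

// Create a watch node to display the value.
const watchNode = LiteGraph.createNode("basic/watch");
watchNode.pos = [400, 200];
graph.add(watchNode);

// Wire output slot 0 of the const node into input slot 0 of the watch node.
constNode.connect(0, watchNode, 0);

// Render into a canvas element and start executing; an interval of 0 means
// the loop follows the monitor refresh rate (see LGraph.start above).
new LGraphCanvas("#graph-canvas", graph);
graph.start();

The rgthree-specific members marked with @rgthree in these typings layer additional, narrower signatures on top of this same surface without changing how the graph itself is driven.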
+interface CanvasDivDialog extends HTMLDivElement { + close: () => void; + modified: () => void; + is_modified: boolean; +} + +/** + * This class is in charge of rendering one graph inside a canvas. And provides all the interaction required. + * Valid callbacks are: onNodeSelected, onNodeDeselected, onShowNodePanel, onNodeDblClicked + * + * @param canvas the canvas where you want to render (it accepts a selector in string format or the canvas element itself) + * @param graph + * @param options { skip_rendering, autoresize } + */ +export declare class LGraphCanvas { + static node_colors: Record< + string, + { + color: string; + bgcolor: string; + groupcolor: string; + } + >; + static link_type_colors: Record; + static gradients: object; + static search_limit: number; + + static getFileExtension(url: string): string; + static decodeHTML(str: string): string; + + static onMenuCollapseAll(): void; + static onMenuNodeEdit(): void; + static onShowPropertyEditor( + item: any, + options: any, + e: any, + menu: any, + node: any + ): void; + /** Create menu for `Add Group` */ + static onGroupAdd: ContextMenuEventListener; + /** Create menu for `Add Node` */ + static onMenuAdd: ContextMenuEventListener; + static showMenuNodeOptionalInputs: ContextMenuEventListener; + static showMenuNodeOptionalOutputs: ContextMenuEventListener; + static onShowMenuNodeProperties: ContextMenuEventListener; + static onResizeNode: ContextMenuEventListener; + static onMenuNodeCollapse: ContextMenuEventListener; + static onMenuNodePin: ContextMenuEventListener; + static onMenuNodeMode: ContextMenuEventListener; + static onMenuNodeColors: ContextMenuEventListener; + static onMenuNodeShapes: ContextMenuEventListener; + static onMenuNodeRemove: ContextMenuEventListener; + static onMenuNodeClone: ContextMenuEventListener; + + // @rgthree + static onShowPropertyEditor: ContextMenuEventListener; + + constructor( + canvas: HTMLCanvasElement | string, + graph?: LGraph, + options?: { + skip_render?: boolean; + autoresize?: boolean; + } + ); + + // @rgthree. This was "HTMLCanvasElement" but that is just wrong... it's LGraphCanvas + static active_canvas: LGraphCanvas; + + // @rgthree + pointer_is_down: boolean; + + allow_dragcanvas: boolean; + allow_dragnodes: boolean; + /** allow to control widgets, buttons, collapse, etc */ + allow_interaction: boolean; + /** allows to change a connection with having to redo it again */ + allow_reconnect_links: boolean; + /** allow selecting multi nodes without pressing extra keys */ + multi_select: boolean; + /** No effect */ + allow_searchbox: boolean; + always_render_background: boolean; + autoresize?: boolean; + background_image: string; + bgcanvas: HTMLCanvasElement; + bgctx: CanvasRenderingContext2D; + canvas: HTMLCanvasElement; + canvas_mouse: Vector2; + // @rgthree - Looks like this is to replace canvas_mouse. + graph_mouse: Vector2; + clear_background: boolean; + // connecting_node: LGraphNode | null; + // // @rgthree - for overriding. 
+ // _connecting_node: LGraphNode | null; + // @rgthree + // connecting_input: INodeInputSlot | null; + // // @rgthree + // connecting_output: INodeOutputSlot | null; + // // @rgthree + // connecting_slot: number; + // // @rgthree + // connecting_pos: Vector2 | null; + // @rgthree - for some reason, the new comfyUI update renamed connecting_node to connecting_links (maybe) + connecting_links: { + node: LGraphNode, input?: INodeInputSlot, output?: INodeOutputSlot, pos: Vector2, slot: number + }[] | null; + _connecting_links: this['connecting_links'] | null; + + connections_width: number; + ctx: CanvasRenderingContext2D; + current_node: LGraphNode | null; + default_connection_color: { + input_off: string; + input_on: string; + output_off: string; + output_on: string; + }; + default_link_color: string; + dirty_area: Vector4 | null; + dirty_bgcanvas?: boolean; + dirty_canvas?: boolean; + drag_mode: boolean; + dragging_canvas: boolean; + dragging_rectangle: Vector4 | null; + // @rgthree; mark undefined. + // It doesn't look like this should ever be undefined.. but something changed in Comfy and folks + // reported https://github.com/rgthree/rgthree-comfy/issues/71. I couldn't reproduce, but we can + // handle it. + ds?: DragAndScale; + /** used for transition */ + editor_alpha: number; + filter: any; + fps: number; + frame: number; + graph: LGraph; + highlighted_links: Record; + highquality_render: boolean; + inner_text_font: string; + is_rendering: boolean; + last_draw_time: number; + last_mouse: Vector2; + /** + * Possible duplicated with `last_mouse` + * https://github.com/jagenjo/litegraph.js/issues/70 + */ + last_mouse_position: Vector2; + /** Timestamp of last mouse click, defaults to 0 */ + last_mouseclick: number; + links_render_mode: + | typeof LiteGraph.STRAIGHT_LINK + | typeof LiteGraph.LINEAR_LINK + | typeof LiteGraph.SPLINE_LINK; + live_mode: boolean; + node_capturing_input: LGraphNode | null; + node_dragged: LGraphNode | null; + node_in_panel: LGraphNode | null; + node_over: LGraphNode | null; + node_title_color: string; + node_widget: [LGraphNode, IWidget] | null; + /** Called by `LGraphCanvas.drawBackCanvas` */ + onDrawBackground: + | ((ctx: CanvasRenderingContext2D, visibleArea: Vector4) => void) + | null; + /** Called by `LGraphCanvas.drawFrontCanvas` */ + onDrawForeground: + | ((ctx: CanvasRenderingContext2D, visibleArea: Vector4) => void) + | null; + onDrawOverlay: ((ctx: CanvasRenderingContext2D) => void) | null; + /** Called by `LGraphCanvas.processMouseDown` */ + onMouse: ((event: MouseEvent) => boolean) | null; + /** Called by `LGraphCanvas.drawFrontCanvas` and `LGraphCanvas.drawLinkTooltip` */ + onDrawLinkTooltip: ((ctx: CanvasRenderingContext2D, link: LLink, _this: this) => void) | null; + /** Called by `LGraphCanvas.selectNodes` */ + onNodeMoved: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.processNodeSelected` */ + onNodeSelected: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.deselectNode` */ + onNodeDeselected: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.processNodeDblClicked` */ + onShowNodePanel: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.processNodeDblClicked` */ + onNodeDblClicked: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.selectNodes` */ + onSelectionChange: ((nodes: Record) => void) | null; + /** Called by `LGraphCanvas.showSearchBox` */ + onSearchBox: + | (( + helper: Element, + value: string, + graphCanvas: LGraphCanvas + ) => string[]) + | 
null; + onSearchBoxSelection: + | ((name: string, event: MouseEvent, graphCanvas: LGraphCanvas) => void) + | null; + pause_rendering: boolean; + render_canvas_border: boolean; + render_collapsed_slots: boolean; + render_connection_arrows: boolean; + render_connections_border: boolean; + render_connections_shadows: boolean; + render_curved_connections: boolean; + render_execution_order: boolean; + render_only_selected: boolean; + render_shadows: boolean; + render_title_colored: boolean; + round_radius: number; + selected_group: null | LGraphGroup; + selected_group_resizing: boolean; + selected_nodes: Record; + show_info: boolean; + title_text_font: string; + /** set to true to render title bar with gradients */ + use_gradients: boolean; + visible_area: DragAndScale["visible_area"]; + visible_links: LLink[]; + visible_nodes: LGraphNode[]; + zoom_modify_alpha: boolean; + + /** clears all the data inside */ + clear(): void; + /** assigns a graph, you can reassign graphs to the same canvas */ + setGraph(graph: LGraph, skipClear?: boolean): void; + /** opens a graph contained inside a node in the current graph */ + openSubgraph(graph: LGraph): void; + /** closes a subgraph contained inside a node */ + closeSubgraph(): void; + /** assigns a canvas */ + setCanvas(canvas: HTMLCanvasElement, skipEvents?: boolean): void; + /** binds mouse, keyboard, touch and drag events to the canvas */ + bindEvents(): void; + /** unbinds mouse events from the canvas */ + unbindEvents(): void; + + /** + * this function allows to render the canvas using WebGL instead of Canvas2D + * this is useful if you plant to render 3D objects inside your nodes, it uses litegl.js for webgl and canvas2DtoWebGL to emulate the Canvas2D calls in webGL + **/ + enableWebGL(): void; + + /** + * marks as dirty the canvas, this way it will be rendered again + * @param fg if the foreground canvas is dirty (the one containing the nodes) + * @param bg if the background canvas is dirty (the one containing the wires) + */ + setDirty(fg: boolean, bg: boolean): void; + + /** + * Used to attach the canvas in a popup + * @return the window where the canvas is attached (the DOM root node) + */ + getCanvasWindow(): Window; + /** starts rendering the content of the canvas when needed */ + startRendering(): void; + /** stops rendering the content of the canvas (to save resources) */ + stopRendering(): void; + + processMouseDown(e: MouseEvent): boolean | undefined; + processMouseMove(e: MouseEvent): boolean | undefined; + processMouseUp(e: MouseEvent): boolean | undefined; + processMouseWheel(e: MouseEvent): boolean | undefined; + + /** returns true if a position (in graph space) is on top of a node little corner box */ + isOverNodeBox(node: LGraphNode, canvasX: number, canvasY: number): boolean; + /** returns true if a position (in graph space) is on top of a node input slot */ + isOverNodeInput( + node: LGraphNode, + canvasX: number, + canvasY: number, + slotPos: Vector2 + ): boolean; + + /** process a key event */ + processKey(e: KeyboardEvent): boolean | undefined; + + // @rgthree - added param + copyToClipboard(nodes: LGraphNode[]|{[key:number]:LGraphNode}): void; + pasteFromClipboard(): void; + processDrop(e: DragEvent): void; + checkDropItem(e: DragEvent): void; + processNodeDblClicked(n: LGraphNode): void; + processNodeSelected(n: LGraphNode, e: MouseEvent): void; + processNodeDeselected(node: LGraphNode): void; + + /** selects a given node (or adds it to the current selection) */ + selectNode(node: LGraphNode, add?: boolean): void; + /** 
selects several nodes (or adds them to the current selection) */ + selectNodes(nodes?: LGraphNode[], add?: boolean): void; + /** removes a node from the current selection */ + deselectNode(node: LGraphNode): void; + /** removes all nodes from the current selection */ + deselectAllNodes(): void; + /** deletes all nodes in the current selection from the graph */ + deleteSelectedNodes(): void; + + /** centers the camera on a given node */ + // @rgthree - narrow parameter + centerOnNode(node: {pos: Vector2, size: Vector2}): void; + /** changes the zoom level of the graph (default is 1), you can pass also a place used to pivot the zoom */ + setZoom(value: number, center: Vector2): void; + /** brings a node to front (above all other nodes) */ + bringToFront(node: LGraphNode): void; + /** sends a node to the back (below all other nodes) */ + sendToBack(node: LGraphNode): void; + /** checks which nodes are visible (inside the camera area) */ + computeVisibleNodes(nodes: LGraphNode[]): LGraphNode[]; + /** renders the whole canvas content, by rendering in two separated canvas, one containing the background grid and the connections, and one containing the nodes) */ + draw(forceFG?: boolean, forceBG?: boolean): void; + /** draws the front canvas (the one containing all the nodes) */ + drawFrontCanvas(): void; + /** draws some useful stats in the corner of the canvas */ + renderInfo(ctx: CanvasRenderingContext2D, x: number, y: number): void; + /** draws the back canvas (the one containing the background and the connections) */ + drawBackCanvas(): void; + /** draws the given node inside the canvas */ + drawNode(node: LGraphNode, ctx: CanvasRenderingContext2D): void; + /** draws graphic for node's slot */ + drawSlotGraphic(ctx: CanvasRenderingContext2D, pos: number[], shape: SlotShape, horizontal: boolean): void; + /** draws the shape of the given node in the canvas */ + drawNodeShape( + node: LGraphNode, + ctx: CanvasRenderingContext2D, + size: [number, number], + fgColor: string, + bgColor: string, + selected: boolean, + mouseOver: boolean + ): void; + /** draws every connection visible in the canvas */ + drawConnections(ctx: CanvasRenderingContext2D): void; + /** + * draws a link between two points + * @param a start pos + * @param b end pos + * @param link the link object with all the link info + * @param skipBorder ignore the shadow of the link + * @param flow show flow animation (for events) + * @param color the color for the link + * @param startDir the direction enum + * @param endDir the direction enum + * @param numSublines number of sublines (useful to represent vec3 or rgb) + **/ + renderLink( + a: Vector2, + b: Vector2, + link: object, + skipBorder: boolean, + flow: boolean, + color?: string, + startDir?: number, + endDir?: number, + numSublines?: number + ): void; + + computeConnectionPoint( + a: Vector2, + b: Vector2, + t: number, + startDir?: number, + endDir?: number + ): void; + + drawExecutionOrder(ctx: CanvasRenderingContext2D): void; + /** draws the widgets stored inside a node */ + drawNodeWidgets( + node: LGraphNode, + posY: number, + ctx: CanvasRenderingContext2D, + activeWidget: object + ): void; + /** process an event on widgets */ + processNodeWidgets( + node: LGraphNode, + pos: Vector2, + event: Event, + activeWidget: object + ): void; + /** draws every group area in the background */ + drawGroups(canvas: any, ctx: CanvasRenderingContext2D): void; + adjustNodesSize(): void; + /** resizes the canvas to a given size, if no size is passed, then it tries to fill the parentNode */ 
+ resize(width?: number, height?: number): void; + /** + * switches to live mode (node shapes are not rendered, only the content) + * this feature was designed when graphs where meant to create user interfaces + **/ + switchLiveMode(transition?: boolean): void; + onNodeSelectionChange(): void; + touchHandler(event: TouchEvent): void; + + showLinkMenu(link: LLink, e: any): false; + prompt( + title: string, + value: any, + callback: Function, + event: any + ): HTMLDivElement; + showSearchBox(event?: MouseEvent): void; + showEditPropertyValue(node: LGraphNode, property: any, options: any): void; + createDialog( + html: string, + options?: { position?: Vector2; event?: MouseEvent } + // @rgthree - Fix return type from void (added above) + ): CanvasDivDialog; + + + + convertOffsetToCanvas: DragAndScale["convertOffsetToCanvas"]; + convertCanvasToOffset: DragAndScale["convertCanvasToOffset"]; + /** converts event coordinates from canvas2D to graph coordinates */ + // @rgthree - change MouseEvent to less restrictive {clientX: number, clientY: number} that + // implementation uses. + convertEventToCanvasOffset(e: {clientX: number, clientY: number}): Vector2; + /** adds some useful properties to a mouse event, like the position in graph coordinates */ + adjustMouseEvent(e: MouseEvent): void; + + getCanvasMenuOptions(): ContextMenuItem[]; + getNodeMenuOptions(node: LGraphNode): ContextMenuItem[]; + getGroupMenuOptions(): ContextMenuItem[]; + /** Called by `getCanvasMenuOptions`, replace default options */ + getMenuOptions?(): ContextMenuItem[]; + /** Called by `getCanvasMenuOptions`, append to default options */ + getExtraMenuOptions?(): ContextMenuItem[]; + /** Called when mouse right click */ + processContextMenu(node: LGraphNode, event: Event): void; + + // @rgthree - Adding this for ComfyUI, since they add this in their own overload in app.js + selected_group_moving?: boolean; + + // @rgthree + showShowNodePanel(node: LGraphNode): void; +} + +// @rgthree - The adjusted pointer event after calling adjustMouseEvent +export interface AdjustedMouseEvent extends PointerEvent { + deltaX: number; + deltaY: number; + canvasX: number; + canvasY: number; +} + +declare class ContextMenu { + static trigger( + element: HTMLElement, + event_name: string, + params: any, + origin: any + ): void; + static isCursorOverElement(event: MouseEvent, element: HTMLElement): void; + static closeAllContextMenus(window: Window): void; + constructor(values: ContextMenuItem[]|string[], options?: IContextMenuOptions, window?: Window); + options: IContextMenuOptions; + parentMenu?: ContextMenu; + lock: boolean; + current_submenu?: ContextMenu; + addItem( + name: string, + value: ContextMenuItem, + options?: IContextMenuOptions + ): void; + close(e?: MouseEvent, ignore_parent_menu?: boolean): void; + getTopMenu(): void; + getFirstEvent(): void; +} + +declare global { + interface Math { + clamp(v: number, min: number, max: number): number; + } +} diff --git a/rgthree-comfy/src_web/typings/rgthree.d.ts b/rgthree-comfy/src_web/typings/rgthree.d.ts new file mode 100644 index 0000000000000000000000000000000000000000..bf11fa2f51b624ca21f337efe0ee86c8eef191c2 --- /dev/null +++ b/rgthree-comfy/src_web/typings/rgthree.d.ts @@ -0,0 +1,67 @@ +import type { AdjustedMouseEvent, LGraphNode, Vector2 } from "./litegraph.js"; +import type {Constructor} from "./index.js"; +import type {RgthreeBaseVirtualNode} from '../comfyui/base_node.js' + +export type AdjustedMouseCustomEvent = CustomEvent<{ originalEvent: AdjustedMouseEvent }>; + + +export 
interface RgthreeBaseNodeConstructor extends Constructor { + static type: string; + static category: string; + static comfyClass: string; + static exposedActions: string[]; +} + +export interface RgthreeBaseVirtualNodeConstructor extends Constructor { + static type: string; + static category: string; + static _category: string; +} + + +export interface RgthreeBaseServerNodeConstructor extends Constructor { + static nodeType: ComfyNodeConstructor; + static nodeData: ComfyObjectInfo; + static __registeredForOverride__: boolean; + onRegisteredForOverride(comfyClass: any, rgthreeClass: any) : void; +} + + +export type RgthreeModelInfo = { + file?: string; + name?: string; + type?: string; + baseModel?: string; + baseModelFile?: string; + links?: string[]; + strengthMin?: number; + strengthMax?: number; + triggerWords?: string[]; + trainedWords?: { + word: string; + count?: number; + civitai?: boolean + user?: boolean + }[]; + description?: string; + sha256?: string; + path?: string; + images?: { + url: string; + civitaiUrl?: string; + steps?: string|number; + cfg?: string|number; + type?: 'image'|'video'; + sampler?: string; + model?: string; + seed?: string; + negative?: string; + positive?: string; + resources?: {name?: string, type?: string, weight?: string|number}[]; + }[] + userTags?: string[]; + userNote?: string; + raw?: any; + // This one is just on the client. + filterDir?: string; +} diff --git a/rgthree-comfy/tsconfig.json b/rgthree-comfy/tsconfig.json new file mode 100644 index 0000000000000000000000000000000000000000..47e3cca445a06903bc221bf0a0a9284e21145b7e --- /dev/null +++ b/rgthree-comfy/tsconfig.json @@ -0,0 +1,45 @@ +{ + "compilerOptions": { + "target": "es2019", + "module": "ESNext", + // "typeRoots": [ + // "./ts/typings", + // ], + "baseUrl": "./", + "paths": { + "typings/*": ["src_web/typings/*"], + "rgthree/common/*": ["src_web/common/*"], + "node_modules": ["node_modules/*"], + "scripts/*": ["src_web/scripts_comfy/*"], + }, + "outDir": "web/", + "removeComments": true, + "strict": true, + "noImplicitAny": true, + "strictNullChecks": true, + "strictFunctionTypes": true, + "strictBindCallApply": true, + "strictPropertyInitialization": true, + "noImplicitThis": true, + "useUnknownInCatchVariables": true, + "alwaysStrict": true, + // "noUnusedLocals": true, + // "noUnusedParameters": true, + "exactOptionalPropertyTypes": false, + "noImplicitReturns": true, + "noFallthroughCasesInSwitch": true, + "noUncheckedIndexedAccess": true, + "noImplicitOverride": true, + "noPropertyAccessFromIndexSignature": true, + "allowUnusedLabels": true, + "skipLibCheck": true, + }, + "include": [ + "src_web/*.ts", "src_web/**/*.ts", "src_web/typings/index.d.ts", + ], + "exclude": [ + "**/*.spec.ts", + "**/*.d.ts", + "node_modules/**/*.ts" + ] +} diff --git a/rgthree-comfy/web/comfyui/any_switch.js b/rgthree-comfy/web/comfyui/any_switch.js new file mode 100644 index 0000000000000000000000000000000000000000..50e0d796eb939bab92ef8084c1939d7cfdc66d02 --- /dev/null +++ b/rgthree-comfy/web/comfyui/any_switch.js @@ -0,0 +1,69 @@ +import { app } from "../../scripts/app.js"; +import { IoDirection, addConnectionLayoutSupport, followConnectionUntilType } from "./utils.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; +import { removeUnusedInputsFromEnd } from "./utils_inputs_outputs.js"; +import { debounce } from "../../rgthree/common/shared_utils.js"; +class RgthreeAnySwitch extends RgthreeBaseServerNode { + constructor(title = 
RgthreeAnySwitch.title) { + super(title); + this.stabilizeBound = this.stabilize.bind(this); + this.nodeType = null; + this.addAnyInput(5); + } + onConnectionsChange(type, slotIndex, isConnected, linkInfo, ioSlot) { + var _a; + (_a = super.onConnectionsChange) === null || _a === void 0 ? void 0 : _a.call(this, type, slotIndex, isConnected, linkInfo, ioSlot); + this.scheduleStabilize(); + } + onConnectionsChainChange() { + this.scheduleStabilize(); + } + scheduleStabilize(ms = 64) { + return debounce(this.stabilizeBound, ms); + } + addAnyInput(num = 1) { + for (let i = 0; i < num; i++) { + this.addInput(`any_${String(this.inputs.length + 1).padStart(2, "0")}`, (this.nodeType || "*")); + } + } + stabilize() { + removeUnusedInputsFromEnd(this, 4); + this.addAnyInput(); + let connectedType = followConnectionUntilType(this, IoDirection.INPUT, undefined, true); + if (!connectedType) { + connectedType = followConnectionUntilType(this, IoDirection.OUTPUT, undefined, true); + } + this.nodeType = (connectedType === null || connectedType === void 0 ? void 0 : connectedType.type) || "*"; + for (const input of this.inputs) { + input.type = this.nodeType; + } + for (const output of this.outputs) { + output.type = this.nodeType; + output.label = + output.type === "RGTHREE_CONTEXT" + ? "CONTEXT" + : Array.isArray(this.nodeType) || this.nodeType.includes(",") + ? (connectedType === null || connectedType === void 0 ? void 0 : connectedType.label) || String(this.nodeType) + : String(this.nodeType); + } + } + static setUp(comfyClass, nodeData) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, RgthreeAnySwitch); + addConnectionLayoutSupport(RgthreeAnySwitch, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + } +} +RgthreeAnySwitch.title = NodeTypesString.ANY_SWITCH; +RgthreeAnySwitch.type = NodeTypesString.ANY_SWITCH; +RgthreeAnySwitch.comfyClass = NodeTypesString.ANY_SWITCH; +app.registerExtension({ + name: "rgthree.AnySwitch", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name === "Any Switch (rgthree)") { + RgthreeAnySwitch.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/base_any_input_connected_node.js b/rgthree-comfy/web/comfyui/base_any_input_connected_node.js new file mode 100644 index 0000000000000000000000000000000000000000..83d196d87015a4333c97f378aee46d71a9453fc5 --- /dev/null +++ b/rgthree-comfy/web/comfyui/base_any_input_connected_node.js @@ -0,0 +1,195 @@ +import { app } from "../../scripts/app.js"; +import { RgthreeBaseVirtualNode } from "./base_node.js"; +import { rgthree } from "./rgthree.js"; +import { PassThroughFollowing, addConnectionLayoutSupport, addMenuItem, getConnectedInputNodes, getConnectedInputNodesAndFilterPassThroughs, getConnectedOutputNodes, getConnectedOutputNodesAndFilterPassThroughs, } from "./utils.js"; +export class BaseAnyInputConnectedNode extends RgthreeBaseVirtualNode { + constructor(title = BaseAnyInputConnectedNode.title) { + super(title); + this.isVirtualNode = true; + this.inputsPassThroughFollowing = PassThroughFollowing.NONE; + this.debouncerTempWidth = 0; + this.schedulePromise = null; + } + onConstructed() { + this.addInput("", "*"); + return super.onConstructed(); + } + scheduleStabilizeWidgets(ms = 100) { + if (!this.schedulePromise) { + this.schedulePromise = new Promise((resolve) => { + setTimeout(() => { + this.schedulePromise = null; + this.doStablization(); + resolve(); + }, ms); + }); + } + return this.schedulePromise; + } + clone() { + const cloned = super.clone(); + if 
(!rgthree.canvasCurrentlyCopyingToClipboardWithMultipleNodes) { + while (cloned.inputs.length > 1) { + cloned.removeInput(cloned.inputs.length - 1); + } + if (cloned.inputs[0]) { + cloned.inputs[0].label = ""; + } + } + return cloned; + } + stabilizeInputsOutputs() { + var _a; + const hasEmptyInput = !((_a = this.inputs[this.inputs.length - 1]) === null || _a === void 0 ? void 0 : _a.link); + if (!hasEmptyInput) { + this.addInput("", "*"); + } + for (let index = this.inputs.length - 2; index >= 0; index--) { + const input = this.inputs[index]; + if (!input.link) { + this.removeInput(index); + } + else { + const node = getConnectedInputNodesAndFilterPassThroughs(this, this, index, this.inputsPassThroughFollowing)[0]; + input.name = (node === null || node === void 0 ? void 0 : node.title) || ""; + } + } + } + doStablization() { + if (!this.graph) { + return; + } + this._tempWidth = this.size[0]; + const linkedNodes = getConnectedInputNodesAndFilterPassThroughs(this); + this.stabilizeInputsOutputs(); + this.handleLinkedNodesStabilization(linkedNodes); + app.graph.setDirtyCanvas(true, true); + this.scheduleStabilizeWidgets(500); + } + handleLinkedNodesStabilization(linkedNodes) { + linkedNodes; + throw new Error("handleLinkedNodesStabilization should be overridden."); + } + onConnectionsChainChange() { + this.scheduleStabilizeWidgets(); + } + onConnectionsChange(type, index, connected, linkInfo, ioSlot) { + super.onConnectionsChange && + super.onConnectionsChange(type, index, connected, linkInfo, ioSlot); + if (!linkInfo) + return; + const connectedNodes = getConnectedOutputNodesAndFilterPassThroughs(this); + for (const node of connectedNodes) { + if (node.onConnectionsChainChange) { + node.onConnectionsChainChange(); + } + } + this.scheduleStabilizeWidgets(); + } + removeInput(slot) { + this._tempWidth = this.size[0]; + return super.removeInput(slot); + } + addInput(name, type, extra_info) { + this._tempWidth = this.size[0]; + return super.addInput(name, type, extra_info); + } + addWidget(type, name, value, callback, options) { + this._tempWidth = this.size[0]; + return super.addWidget(type, name, value, callback, options); + } + removeWidget(widgetOrSlot) { + this._tempWidth = this.size[0]; + super.removeWidget(widgetOrSlot); + } + computeSize(out) { + var _a, _b; + let size = super.computeSize(out); + if (this._tempWidth) { + size[0] = this._tempWidth; + this.debouncerTempWidth && clearTimeout(this.debouncerTempWidth); + this.debouncerTempWidth = setTimeout(() => { + this._tempWidth = null; + }, 32); + } + if (this.properties["collapse_connections"]) { + const rows = Math.max(((_a = this.inputs) === null || _a === void 0 ? void 0 : _a.length) || 0, ((_b = this.outputs) === null || _b === void 0 ? void 0 : _b.length) || 0, 1) - 1; + size[1] = size[1] - rows * LiteGraph.NODE_SLOT_HEIGHT; + } + setTimeout(() => { + app.graph.setDirtyCanvas(true, true); + }, 16); + return size; + } + onConnectOutput(outputIndex, inputType, inputSlot, inputNode, inputIndex) { + let canConnect = true; + if (super.onConnectOutput) { + canConnect = super.onConnectOutput(outputIndex, inputType, inputSlot, inputNode, inputIndex); + } + if (canConnect) { + const nodes = getConnectedInputNodes(this); + if (nodes.includes(inputNode)) { + alert(`Whoa, whoa, whoa. 
You've just tried to create a connection that loops back on itself, ` + + `a situation that could create a time paradox, the results of which could cause a ` + + `chain reaction that would unravel the very fabric of the space time continuum, ` + + `and destroy the entire universe!`); + canConnect = false; + } + } + return canConnect; + } + onConnectInput(inputIndex, outputType, outputSlot, outputNode, outputIndex) { + let canConnect = true; + if (super.onConnectInput) { + canConnect = super.onConnectInput(inputIndex, outputType, outputSlot, outputNode, outputIndex); + } + if (canConnect) { + const nodes = getConnectedOutputNodes(this); + if (nodes.includes(outputNode)) { + alert(`Whoa, whoa, whoa. You've just tried to create a connection that loops back on itself, ` + + `a situation that could create a time paradox, the results of which could cause a ` + + `chain reaction that would unravel the very fabric of the space time continuum, ` + + `and destroy the entire universe!`); + canConnect = false; + } + } + return canConnect; + } + connectByTypeOutput(slot, sourceNode, sourceSlotType, optsIn) { + const lastInput = this.inputs[this.inputs.length - 1]; + if (!(lastInput === null || lastInput === void 0 ? void 0 : lastInput.link) && (lastInput === null || lastInput === void 0 ? void 0 : lastInput.type) === "*") { + var sourceSlot = sourceNode.findOutputSlotByType(sourceSlotType, false, true); + return sourceNode.connect(sourceSlot, this, slot); + } + return super.connectByTypeOutput(slot, sourceNode, sourceSlotType, optsIn); + } + static setUp() { + super.setUp(); + addConnectionLayoutSupport(this, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + addMenuItem(this, app, { + name: (node) => { var _a; return `${((_a = node.properties) === null || _a === void 0 ? void 0 : _a["collapse_connections"]) ? "Show" : "Collapse"} Connections`; }, + property: "collapse_connections", + prepareValue: (_value, node) => { var _a; return !((_a = node.properties) === null || _a === void 0 ? 
void 0 : _a["collapse_connections"]); }, + callback: (_node) => { + app.graph.setDirtyCanvas(true, true); + }, + }); + } +} +const oldLGraphNodeConnectByType = LGraphNode.prototype.connectByType; +LGraphNode.prototype.connectByType = function connectByType(slot, sourceNode, sourceSlotType, optsIn) { + if (sourceNode.inputs) { + for (const [index, input] of sourceNode.inputs.entries()) { + if (!input.link && input.type === "*") { + this.connect(slot, sourceNode, index); + return null; + } + } + } + return ((oldLGraphNodeConnectByType && + oldLGraphNodeConnectByType.call(this, slot, sourceNode, sourceSlotType, optsIn)) || + null); +}; diff --git a/rgthree-comfy/web/comfyui/base_node.js b/rgthree-comfy/web/comfyui/base_node.js new file mode 100644 index 0000000000000000000000000000000000000000..ea34c57313ba8d832a4fe790537eece5180aebea --- /dev/null +++ b/rgthree-comfy/web/comfyui/base_node.js @@ -0,0 +1,288 @@ +import { ComfyWidgets } from "../../scripts/widgets.js"; +import { SERVICE as KEY_EVENT_SERVICE } from "./services/key_events_services.js"; +import { app } from "../../scripts/app.js"; +import { LogLevel, rgthree } from "./rgthree.js"; +import { addHelpMenuItem } from "./utils.js"; +import { RgthreeHelpDialog } from "../../rgthree/common/dialog.js"; +import { importIndividualNodesInnerOnDragDrop, importIndividualNodesInnerOnDragOver, } from "./feature_import_individual_nodes.js"; +export class RgthreeBaseNode extends LGraphNode { + constructor(title = RgthreeBaseNode.title, skipOnConstructedCall = true) { + super(title); + this.comfyClass = "__NEED_COMFY_CLASS__"; + this.nickname = "rgthree"; + this.isVirtualNode = false; + this.isDropEnabled = false; + this.removed = false; + this.configuring = false; + this._tempWidth = 0; + this.__constructed__ = false; + this.helpDialog = null; + if (title == "__NEED_CLASS_TITLE__") { + throw new Error("RgthreeBaseNode needs overrides."); + } + this.widgets = this.widgets || []; + this.properties = this.properties || {}; + setTimeout(() => { + if (this.comfyClass == "__NEED_COMFY_CLASS__") { + throw new Error("RgthreeBaseNode needs a comfy class override."); + } + this.checkAndRunOnConstructed(); + }); + } + checkAndRunOnConstructed() { + var _a; + if (!this.__constructed__) { + this.onConstructed(); + const [n, v] = rgthree.logger.logParts(LogLevel.DEV, `[RgthreeBaseNode] Child class did not call onConstructed for "${this.type}.`); + (_a = console[n]) === null || _a === void 0 ? void 0 : _a.call(console, ...v); + } + return this.__constructed__; + } + onDragOver(e) { + if (!this.isDropEnabled) + return false; + return importIndividualNodesInnerOnDragOver(this, e); + } + async onDragDrop(e) { + if (!this.isDropEnabled) + return false; + return importIndividualNodesInnerOnDragDrop(this, e); + } + onConstructed() { + var _a; + if (this.__constructed__) + return false; + this.type = (_a = this.type) !== null && _a !== void 0 ? 
_a : undefined; + this.__constructed__ = true; + rgthree.invokeExtensionsAsync("nodeCreated", this); + return this.__constructed__; + } + configure(info) { + this.configuring = true; + super.configure(info); + for (const w of this.widgets || []) { + w.last_y = w.last_y || 0; + } + this.configuring = false; + } + clone() { + const cloned = super.clone(); + if (cloned.properties && !!window.structuredClone) { + cloned.properties = structuredClone(cloned.properties); + } + return cloned; + } + set mode(mode) { + if (this.mode_ != mode) { + const oldMode = this.mode_; + this.mode_ = mode; + this.onModeChange(oldMode, mode); + } + } + get mode() { + return this.mode_; + } + onModeChange(from, to) { + } + async handleAction(action) { + action; + } + removeWidget(widgetOrSlot) { + if (typeof widgetOrSlot === "number") { + this.widgets.splice(widgetOrSlot, 1); + } + else if (widgetOrSlot) { + const index = this.widgets.indexOf(widgetOrSlot); + if (index > -1) { + this.widgets.splice(index, 1); + } + } + } + defaultGetSlotMenuOptions(slot) { + var _a, _b; + const menu_info = []; + if ((_b = (_a = slot === null || slot === void 0 ? void 0 : slot.output) === null || _a === void 0 ? void 0 : _a.links) === null || _b === void 0 ? void 0 : _b.length) { + menu_info.push({ content: "Disconnect Links", slot: slot }); + } + let inputOrOutput = slot.input || slot.output; + if (inputOrOutput) { + if (inputOrOutput.removable) { + menu_info.push(inputOrOutput.locked ? { content: "Cannot remove" } : { content: "Remove Slot", slot }); + } + if (!inputOrOutput.nameLocked) { + menu_info.push({ content: "Rename Slot", slot }); + } + } + return menu_info; + } + onRemoved() { + var _a; + (_a = super.onRemoved) === null || _a === void 0 ? void 0 : _a.call(this); + this.removed = true; + } + static setUp(...args) { + } + getHelp() { + return ""; + } + showHelp() { + const help = this.getHelp() || this.constructor.help; + if (help) { + this.helpDialog = new RgthreeHelpDialog(this, help).show(); + this.helpDialog.addEventListener("close", (e) => { + this.helpDialog = null; + }); + } + } + onKeyDown(event) { + KEY_EVENT_SERVICE.handleKeyDownOrUp(event); + if (event.key == "?" && !this.helpDialog) { + this.showHelp(); + } + } + onKeyUp(event) { + KEY_EVENT_SERVICE.handleKeyDownOrUp(event); + } + getExtraMenuOptions(canvas, options) { + var _a, _b, _c, _d, _e, _f; + if (super.getExtraMenuOptions) { + (_a = super.getExtraMenuOptions) === null || _a === void 0 ? void 0 : _a.apply(this, [canvas, options]); + } + else if ((_c = (_b = this.constructor.nodeType) === null || _b === void 0 ? void 0 : _b.prototype) === null || _c === void 0 ? void 0 : _c.getExtraMenuOptions) { + (_f = (_e = (_d = this.constructor.nodeType) === null || _d === void 0 ? void 0 : _d.prototype) === null || _e === void 0 ? void 0 : _e.getExtraMenuOptions) === null || _f === void 0 ? 
void 0 : _f.apply(this, [ + canvas, + options, + ]); + } + const help = this.getHelp() || this.constructor.help; + if (help) { + addHelpMenuItem(this, help, options); + } + } +} +RgthreeBaseNode.exposedActions = []; +RgthreeBaseNode.title = "__NEED_CLASS_TITLE__"; +RgthreeBaseNode.category = "rgthree"; +RgthreeBaseNode._category = "rgthree"; +export class RgthreeBaseVirtualNode extends RgthreeBaseNode { + constructor(title = RgthreeBaseNode.title) { + super(title, false); + this.isVirtualNode = true; + } + static setUp() { + if (!this.type) { + throw new Error(`Missing type for RgthreeBaseVirtualNode: ${this.title}`); + } + LiteGraph.registerNodeType(this.type, this); + if (this._category) { + this.category = this._category; + } + } +} +export class RgthreeBaseServerNode extends RgthreeBaseNode { + constructor(title) { + super(title, true); + this.isDropEnabled = true; + this.serialize_widgets = true; + this.setupFromServerNodeData(); + this.onConstructed(); + } + getWidgets() { + return ComfyWidgets; + } + async setupFromServerNodeData() { + var _a, _b, _c; + const nodeData = this.constructor.nodeData; + if (!nodeData) { + throw Error("No node data"); + } + this.comfyClass = nodeData.name; + let inputs = nodeData["input"]["required"]; + if (nodeData["input"]["optional"] != undefined) { + inputs = Object.assign({}, inputs, nodeData["input"]["optional"]); + } + const WIDGETS = this.getWidgets(); + const config = { + minWidth: 1, + minHeight: 1, + widget: null, + }; + for (const inputName in inputs) { + const inputData = inputs[inputName]; + const type = inputData[0]; + if ((_a = inputData[1]) === null || _a === void 0 ? void 0 : _a.forceInput) { + this.addInput(inputName, type); + } + else { + let widgetCreated = true; + if (Array.isArray(type)) { + Object.assign(config, WIDGETS.COMBO(this, inputName, inputData, app) || {}); + } + else if (`${type}:${inputName}` in WIDGETS) { + Object.assign(config, WIDGETS[`${type}:${inputName}`](this, inputName, inputData, app) || {}); + } + else if (type in WIDGETS) { + Object.assign(config, WIDGETS[type](this, inputName, inputData, app) || {}); + } + else { + this.addInput(inputName, type); + widgetCreated = false; + } + if (widgetCreated && ((_b = inputData[1]) === null || _b === void 0 ? void 0 : _b.forceInput) && (config === null || config === void 0 ? void 0 : config.widget)) { + if (!config.widget.options) + config.widget.options = {}; + config.widget.options.forceInput = inputData[1].forceInput; + } + if (widgetCreated && ((_c = inputData[1]) === null || _c === void 0 ? void 0 : _c.defaultInput) && (config === null || config === void 0 ? void 0 : config.widget)) { + if (!config.widget.options) + config.widget.options = {}; + config.widget.options.defaultInput = inputData[1].defaultInput; + } + } + } + for (const o in nodeData["output"]) { + let output = nodeData["output"][o]; + if (output instanceof Array) + output = "COMBO"; + const outputName = nodeData["output_name"][o] || output; + const outputShape = nodeData["output_is_list"][o] + ? 
LiteGraph.GRID_SHAPE + : LiteGraph.CIRCLE_SHAPE; + this.addOutput(outputName, output, { shape: outputShape }); + } + const s = this.computeSize(); + s[0] = Math.max(config.minWidth, s[0] * 1.5); + s[1] = Math.max(config.minHeight, s[1]); + this.size = s; + this.serialize_widgets = true; + } + static registerForOverride(comfyClass, nodeData, rgthreeClass) { + if (OVERRIDDEN_SERVER_NODES.has(comfyClass)) { + throw Error(`Already have a class to override ${comfyClass.type || comfyClass.name || comfyClass.title}`); + } + OVERRIDDEN_SERVER_NODES.set(comfyClass, rgthreeClass); + if (!rgthreeClass.__registeredForOverride__) { + rgthreeClass.__registeredForOverride__ = true; + rgthreeClass.nodeType = comfyClass; + rgthreeClass.nodeData = nodeData; + rgthreeClass.onRegisteredForOverride(comfyClass, rgthreeClass); + } + } + static onRegisteredForOverride(comfyClass, rgthreeClass) { + } +} +RgthreeBaseServerNode.nodeData = null; +RgthreeBaseServerNode.nodeType = null; +RgthreeBaseServerNode.__registeredForOverride__ = false; +const OVERRIDDEN_SERVER_NODES = new Map(); +const oldregisterNodeType = LiteGraph.registerNodeType; +LiteGraph.registerNodeType = async function (nodeId, baseClass) { + var _a; + const clazz = OVERRIDDEN_SERVER_NODES.get(baseClass) || baseClass; + if (clazz !== baseClass) { + const classLabel = clazz.type || clazz.name || clazz.title; + const [n, v] = rgthree.logger.logParts(LogLevel.DEBUG, `${nodeId}: replacing default ComfyNode implementation with custom ${classLabel} class.`); + (_a = console[n]) === null || _a === void 0 ? void 0 : _a.call(console, ...v); + } + return oldregisterNodeType.call(LiteGraph, nodeId, clazz); +}; diff --git a/rgthree-comfy/web/comfyui/base_node_collector.js b/rgthree-comfy/web/comfyui/base_node_collector.js new file mode 100644 index 0000000000000000000000000000000000000000..f79bdafe2e7953e08f3e64e1aac135ad5f9da2f0 --- /dev/null +++ b/rgthree-comfy/web/comfyui/base_node_collector.js @@ -0,0 +1,50 @@ +import { rgthree } from "./rgthree.js"; +import { BaseAnyInputConnectedNode } from "./base_any_input_connected_node.js"; +import { PassThroughFollowing, getConnectedInputNodes, getConnectedInputNodesAndFilterPassThroughs, shouldPassThrough, } from "./utils.js"; +export class BaseCollectorNode extends BaseAnyInputConnectedNode { + constructor(title) { + super(title); + this.inputsPassThroughFollowing = PassThroughFollowing.REROUTE_ONLY; + this.logger = rgthree.newLogSession("[BaseCollectorNode]"); + } + clone() { + const cloned = super.clone(); + return cloned; + } + handleLinkedNodesStabilization(linkedNodes) { + } + onConnectInput(inputIndex, outputType, outputSlot, outputNode, outputIndex) { + var _a, _b, _c, _d; + let canConnect = super.onConnectInput(inputIndex, outputType, outputSlot, outputNode, outputIndex); + if (canConnect) { + const allConnectedNodes = getConnectedInputNodes(this); + const nodesAlreadyInSlot = getConnectedInputNodes(this, undefined, inputIndex); + if (allConnectedNodes.includes(outputNode)) { + const [n, v] = this.logger.debugParts(`${outputNode.title} is already connected to ${this.title}.`); + (_a = console[n]) === null || _a === void 0 ? void 0 : _a.call(console, ...v); + if (nodesAlreadyInSlot.includes(outputNode)) { + const [n, v] = this.logger.debugParts(`... but letting it slide since it's for the same slot.`); + (_b = console[n]) === null || _b === void 0 ? 
void 0 : _b.call(console, ...v); + } + else { + canConnect = false; + } + } + if (canConnect && shouldPassThrough(outputNode, PassThroughFollowing.REROUTE_ONLY)) { + const connectedNode = getConnectedInputNodesAndFilterPassThroughs(outputNode, undefined, undefined, PassThroughFollowing.REROUTE_ONLY)[0]; + if (connectedNode && allConnectedNodes.includes(connectedNode)) { + const [n, v] = this.logger.debugParts(`${connectedNode.title} is already connected to ${this.title}.`); + (_c = console[n]) === null || _c === void 0 ? void 0 : _c.call(console, ...v); + if (nodesAlreadyInSlot.includes(connectedNode)) { + const [n, v] = this.logger.debugParts(`... but letting it slide since it's for the same slot.`); + (_d = console[n]) === null || _d === void 0 ? void 0 : _d.call(console, ...v); + } + else { + canConnect = false; + } + } + } + } + return canConnect; + } +} diff --git a/rgthree-comfy/web/comfyui/base_node_mode_changer.js b/rgthree-comfy/web/comfyui/base_node_mode_changer.js new file mode 100644 index 0000000000000000000000000000000000000000..f253334bbe24f9124e8c9897a320af2bfccb914c --- /dev/null +++ b/rgthree-comfy/web/comfyui/base_node_mode_changer.js @@ -0,0 +1,84 @@ +import { BaseAnyInputConnectedNode } from "./base_any_input_connected_node.js"; +import { PassThroughFollowing } from "./utils.js"; +import { wait } from "../../rgthree/common/shared_utils.js"; +export class BaseNodeModeChanger extends BaseAnyInputConnectedNode { + constructor(title) { + super(title); + this.inputsPassThroughFollowing = PassThroughFollowing.ALL; + this.isVirtualNode = true; + this.modeOn = -1; + this.modeOff = -1; + this.properties["toggleRestriction"] = "default"; + } + onConstructed() { + wait(10).then(() => { + if (this.modeOn < 0 || this.modeOff < 0) { + throw new Error("modeOn and modeOff must be overridden."); + } + }); + this.addOutput("OPT_CONNECTION", "*"); + return super.onConstructed(); + } + configure(info) { + var _a; + if ((_a = info.outputs) === null || _a === void 0 ? void 0 : _a.length) { + info.outputs.length = 1; + } + super.configure(info); + } + handleLinkedNodesStabilization(linkedNodes) { + for (const [index, node] of linkedNodes.entries()) { + let widget = this.widgets && this.widgets[index]; + if (!widget) { + this._tempWidth = this.size[0]; + widget = this.addWidget("toggle", "", false, "", { on: "yes", off: "no" }); + } + node && this.setWidget(widget, node); + } + if (this.widgets && this.widgets.length > linkedNodes.length) { + this.widgets.length = linkedNodes.length; + } + } + setWidget(widget, linkedNode, forceValue) { + const value = forceValue == null ? linkedNode.mode === this.modeOn : forceValue; + widget.name = `Enable ${linkedNode.title}`; + widget.options = { on: "yes", off: "no" }; + widget.value = value; + widget.doModeChange = (forceValue, skipOtherNodeCheck) => { + var _a, _b, _c; + let newValue = forceValue == null ? linkedNode.mode === this.modeOff : forceValue; + if (skipOtherNodeCheck !== true) { + if (newValue && ((_b = (_a = this.properties) === null || _a === void 0 ? void 0 : _a["toggleRestriction"]) === null || _b === void 0 ? void 0 : _b.includes(" one"))) { + for (const widget of this.widgets) { + widget.doModeChange(false, true); + } + } + else if (!newValue && ((_c = this.properties) === null || _c === void 0 ? void 0 : _c["toggleRestriction"]) === "always one") { + newValue = this.widgets.every((w) => !w.value || w === widget); + } + } + linkedNode.mode = (newValue ? 
this.modeOn : this.modeOff); + widget.value = newValue; + }; + widget.callback = () => { + widget.doModeChange(); + }; + if (forceValue != null) { + linkedNode.mode = (forceValue ? this.modeOn : this.modeOff); + } + } + forceWidgetOff(widget, skipOtherNodeCheck) { + widget.doModeChange(false, skipOtherNodeCheck); + } + forceWidgetOn(widget, skipOtherNodeCheck) { + widget.doModeChange(true, skipOtherNodeCheck); + } + forceWidgetToggle(widget, skipOtherNodeCheck) { + widget.doModeChange(!widget.value, skipOtherNodeCheck); + } +} +BaseNodeModeChanger.collapsible = false; +BaseNodeModeChanger["@toggleRestriction"] = { + type: "combo", + values: ["default", "max one", "always one"], +}; diff --git a/rgthree-comfy/web/comfyui/base_power_prompt.js b/rgthree-comfy/web/comfyui/base_power_prompt.js new file mode 100644 index 0000000000000000000000000000000000000000..9a5f064a82a3441b6580dbb7146b763b582afd98 --- /dev/null +++ b/rgthree-comfy/web/comfyui/base_power_prompt.js @@ -0,0 +1,251 @@ +import { api } from "../../scripts/api.js"; +import { wait } from "../../rgthree/common/shared_utils.js"; +import { rgthree } from "./rgthree.js"; +export class PowerPrompt { + constructor(node, nodeData) { + this.combos = {}; + this.combosValues = {}; + this.configuring = false; + this.node = node; + this.node.properties = this.node.properties || {}; + this.node.properties["combos_filter"] = ""; + this.nodeData = nodeData; + this.isSimple = this.nodeData.name.includes("Simple"); + this.promptEl = node.widgets[0].inputEl; + this.addAndHandleKeyboardLoraEditWeight(); + this.patchNodeRefresh(); + const oldConfigure = this.node.configure; + this.node.configure = (info) => { + this.configuring = true; + oldConfigure === null || oldConfigure === void 0 ? void 0 : oldConfigure.apply(this.node, [info]); + this.configuring = false; + }; + const oldOnConnectionsChange = this.node.onConnectionsChange; + this.node.onConnectionsChange = (type, slotIndex, isConnected, link_info, _ioSlot) => { + oldOnConnectionsChange === null || oldOnConnectionsChange === void 0 ? void 0 : oldOnConnectionsChange.apply(this.node, [type, slotIndex, isConnected, link_info, _ioSlot]); + this.onNodeConnectionsChange(type, slotIndex, isConnected, link_info, _ioSlot); + }; + const oldOnConnectInput = this.node.onConnectInput; + this.node.onConnectInput = (inputIndex, outputType, outputSlot, outputNode, outputIndex) => { + let canConnect = true; + if (oldOnConnectInput) { + canConnect = oldOnConnectInput.apply(this.node, [ + inputIndex, + outputType, + outputSlot, + outputNode, + outputIndex, + ]); + } + return (this.configuring || + rgthree.loadingApiJson || + (canConnect && !this.node.inputs[inputIndex].disabled)); + }; + const oldOnConnectOutput = this.node.onConnectOutput; + this.node.onConnectOutput = (outputIndex, inputType, inputSlot, inputNode, inputIndex) => { + let canConnect = true; + if (oldOnConnectOutput) { + canConnect = oldOnConnectOutput === null || oldOnConnectOutput === void 0 ? 
void 0 : oldOnConnectOutput.apply(this.node, [ + outputIndex, + inputType, + inputSlot, + inputNode, + inputIndex, + ]); + } + return (this.configuring || + rgthree.loadingApiJson || + (canConnect && !this.node.outputs[outputIndex].disabled)); + }; + const onPropertyChanged = this.node.onPropertyChanged; + this.node.onPropertyChanged = (property, value, prevValue) => { + onPropertyChanged && onPropertyChanged.call(this, property, value, prevValue); + if (property === "combos_filter") { + this.refreshCombos(this.nodeData); + } + }; + for (let i = this.node.widgets.length - 1; i >= 0; i--) { + if (this.shouldRemoveServerWidget(this.node.widgets[i])) { + this.node.widgets.splice(i, 1); + } + } + this.refreshCombos(nodeData); + setTimeout(() => { + this.stabilizeInputsOutputs(); + }, 32); + } + onNodeConnectionsChange(_type, _slotIndex, _isConnected, _linkInfo, _ioSlot) { + this.stabilizeInputsOutputs(); + } + stabilizeInputsOutputs() { + if (this.configuring || rgthree.loadingApiJson) { + return; + } + const clipLinked = this.node.inputs.some((i) => i.name.includes("clip") && !!i.link); + const modelLinked = this.node.inputs.some((i) => i.name.includes("model") && !!i.link); + for (const output of this.node.outputs) { + const type = output.type.toLowerCase(); + if (type.includes("model")) { + output.disabled = !modelLinked; + } + else if (type.includes("conditioning")) { + output.disabled = !clipLinked; + } + else if (type.includes("clip")) { + output.disabled = !clipLinked; + } + else if (type.includes("string")) { + output.color_off = "#7F7"; + output.color_on = "#7F7"; + } + if (output.disabled) { + } + } + } + onFreshNodeDefs(event) { + this.refreshCombos(event.detail[this.nodeData.name]); + } + shouldRemoveServerWidget(widget) { + var _a, _b, _c, _d; + return (((_a = widget.name) === null || _a === void 0 ? void 0 : _a.startsWith("insert_")) || + ((_b = widget.name) === null || _b === void 0 ? void 0 : _b.startsWith("target_")) || + ((_c = widget.name) === null || _c === void 0 ? void 0 : _c.startsWith("crop_")) || + ((_d = widget.name) === null || _d === void 0 ? void 0 : _d.startsWith("values_"))); + } + refreshCombos(nodeData) { + var _a, _b, _c; + this.nodeData = nodeData; + let filter = null; + if ((_a = this.node.properties["combos_filter"]) === null || _a === void 0 ? void 0 : _a.trim()) { + try { + filter = new RegExp(this.node.properties["combos_filter"].trim(), "i"); + } + catch (e) { + console.error(`Could not parse "${filter}" for Regular Expression`, e); + filter = null; + } + } + let data = Object.assign({}, ((_b = this.nodeData.input) === null || _b === void 0 ? void 0 : _b.optional) || {}, ((_c = this.nodeData.input) === null || _c === void 0 ? void 0 : _c.hidden) || {}); + for (const [key, value] of Object.entries(data)) { + if (Array.isArray(value[0])) { + let values = value[0]; + if (key.startsWith("insert")) { + values = filter + ? values.filter((v, i) => i < 1 || (i == 1 && v.match(/^disable\s[a-z]/i)) || (filter === null || filter === void 0 ? 
void 0 : filter.test(v)))
+ : values;
+ const shouldShow = values.length > 2 || (values.length > 1 && !values[1].match(/^disable\s[a-z]/i));
+ if (shouldShow) {
+ if (!this.combos[key]) {
+ this.combos[key] = this.node.addWidget("combo", key, values, (selected) => {
+ if (selected !== values[0] && !selected.match(/^disable\s[a-z]/i)) {
+ wait().then(() => {
+ if (key.includes("embedding")) {
+ this.insertSelectionText(`embedding:${selected}`);
+ }
+ else if (key.includes("saved")) {
+ this.insertSelectionText(this.combosValues[`values_${key}`][values.indexOf(selected)]);
+ }
+ else if (key.includes("lora")) {
+ this.insertSelectionText(`<lora:${selected}:1.0>`);
+ }
+ this.combos[key].value = values[0];
+ });
+ }
+ }, {
+ values,
+ serialize: true,
+ });
+ this.combos[key].oldComputeSize = this.combos[key].computeSize;
+ let node = this.node;
+ this.combos[key].computeSize = function (width) {
+ var _a, _b;
+ const size = ((_b = (_a = this).oldComputeSize) === null || _b === void 0 ? void 0 : _b.call(_a, width)) || [
+ width,
+ LiteGraph.NODE_WIDGET_HEIGHT,
+ ];
+ if (this === node.widgets[node.widgets.length - 1]) {
+ size[1] += 10;
+ }
+ return size;
+ };
+ }
+ this.combos[key].options.values = values;
+ this.combos[key].value = values[0];
+ }
+ else if (!shouldShow && this.combos[key]) {
+ this.node.widgets.splice(this.node.widgets.indexOf(this.combos[key]), 1);
+ delete this.combos[key];
+ }
+ }
+ else if (key.startsWith("values")) {
+ this.combosValues[key] = values;
+ }
+ }
+ }
+ }
+ insertSelectionText(text) {
+ if (!this.promptEl) {
+ console.error("Asked to insert text, but no textbox found.");
+ return;
+ }
+ let prompt = this.promptEl.value;
+ let first = prompt.substring(0, this.promptEl.selectionEnd).replace(/ +$/, "");
+ first = first + (["\n"].includes(first[first.length - 1]) ? "" : first.length ? " " : "");
+ let second = prompt.substring(this.promptEl.selectionEnd).replace(/^ +/, "");
+ second = (["\n"].includes(second[0]) ? "" : second.length ? " " : "") + second;
+ this.promptEl.value = first + text + second;
+ this.promptEl.focus();
+ this.promptEl.selectionStart = first.length;
+ this.promptEl.selectionEnd = first.length + text.length;
+ }
+ addAndHandleKeyboardLoraEditWeight() {
+ this.promptEl.addEventListener("keydown", (event) => {
+ var _a, _b;
+ if (!(event.key === "ArrowUp" || event.key === "ArrowDown"))
+ return;
+ if (!event.ctrlKey && !event.metaKey)
+ return;
+ const delta = event.shiftKey ? 0.01 : 0.1;
+ let start = this.promptEl.selectionStart;
+ let end = this.promptEl.selectionEnd;
+ let fullText = this.promptEl.value;
+ let selectedText = fullText.substring(start, end);
+ if (!selectedText) {
+ const stopOn = "<>()\r\n\t";
+ if (fullText[start] == ">") {
+ start -= 2;
+ end -= 2;
+ }
+ if (fullText[end - 1] == "<") {
+ start += 2;
+ end += 2;
+ }
+ while (!stopOn.includes(fullText[start]) && start > 0) {
+ start--;
+ }
+ while (!stopOn.includes(fullText[end - 1]) && end < fullText.length) {
+ end++;
+ }
+ selectedText = fullText.substring(start, end);
+ }
+ if (!selectedText.startsWith("<lora:") || !selectedText.endsWith(">")) {
+ return;
+ }
+ let weight = (_b = Number((_a = selectedText.match(/:(-?\d*(\.\d*)?)>$/)) === null || _a === void 0 ? void 0 : _a[1])) !== null && _b !== void 0 ? _b : 1;
+ weight += event.key === "ArrowUp" ?
delta : -delta; + const updatedText = selectedText.replace(/(:-?\d*(\.\d*)?)?>$/, `:${weight.toFixed(2)}>`); + this.promptEl.setRangeText(updatedText, start, end, "select"); + event.preventDefault(); + event.stopPropagation(); + }); + } + patchNodeRefresh() { + this.boundOnFreshNodeDefs = this.onFreshNodeDefs.bind(this); + api.addEventListener("fresh-node-defs", this.boundOnFreshNodeDefs); + const oldNodeRemoved = this.node.onRemoved; + this.node.onRemoved = () => { + oldNodeRemoved === null || oldNodeRemoved === void 0 ? void 0 : oldNodeRemoved.call(this.node); + api.removeEventListener("fresh-node-defs", this.boundOnFreshNodeDefs); + }; + } +} diff --git a/rgthree-comfy/web/comfyui/bookmark.js b/rgthree-comfy/web/comfyui/bookmark.js new file mode 100644 index 0000000000000000000000000000000000000000..4935a578ad3d9d40c8ca57b9acdaf0167ca05fb7 --- /dev/null +++ b/rgthree-comfy/web/comfyui/bookmark.js @@ -0,0 +1,110 @@ +import { app } from "../../scripts/app.js"; +import { RgthreeBaseVirtualNode } from "./base_node.js"; +import { SERVICE as KEY_EVENT_SERVICE } from "./services/key_events_services.js"; +import { NodeTypesString } from "./constants.js"; +import { getClosestOrSelf, queryOne } from "../../rgthree/common/utils_dom.js"; +export class Bookmark extends RgthreeBaseVirtualNode { + get _collapsed_width() { + return this.___collapsed_width; + } + set _collapsed_width(width) { + const canvas = app.canvas; + const ctx = canvas.canvas.getContext("2d"); + const oldFont = ctx.font; + ctx.font = canvas.title_text_font; + this.___collapsed_width = 40 + ctx.measureText(this.title).width; + ctx.font = oldFont; + } + constructor(title = Bookmark.title) { + super(title); + this.comfyClass = NodeTypesString.BOOKMARK; + this.___collapsed_width = 0; + this.isVirtualNode = true; + this.serialize_widgets = true; + const nextShortcutChar = getNextShortcut(); + this.addWidget("text", "shortcut_key", nextShortcutChar, (value, ...args) => { + value = value.trim()[0] || "1"; + }, { + y: 8, + }); + this.addWidget("number", "zoom", 1, (value) => { }, { + y: 8 + LiteGraph.NODE_WIDGET_HEIGHT + 4, + max: 2, + min: 0.5, + precision: 2, + }); + this.keypressBound = this.onKeypress.bind(this); + this.title = "🔖"; + this.onConstructed(); + } + get shortcutKey() { + var _a, _b, _c; + return (_c = (_b = (_a = this.widgets[0]) === null || _a === void 0 ? void 0 : _a.value) === null || _b === void 0 ? void 0 : _b.toLocaleLowerCase()) !== null && _c !== void 0 ? _c : ""; + } + onAdded(graph) { + KEY_EVENT_SERVICE.addEventListener("keydown", this.keypressBound); + } + onRemoved() { + KEY_EVENT_SERVICE.removeEventListener("keydown", this.keypressBound); + } + onKeypress(event) { + const originalEvent = event.detail.originalEvent; + const target = originalEvent.target; + if (getClosestOrSelf(target, 'input,textarea,[contenteditable="true"]')) { + return; + } + if (KEY_EVENT_SERVICE.areOnlyKeysDown(this.widgets[0].value, true)) { + this.canvasToBookmark(); + originalEvent.preventDefault(); + originalEvent.stopPropagation(); + } + } + onMouseDown(event, pos, graphCanvas) { + var _a; + const input = queryOne(".graphdialog > input.value"); + if (input && input.value === ((_a = this.widgets[0]) === null || _a === void 0 ? 
void 0 : _a.value)) { + input.addEventListener("keydown", (e) => { + KEY_EVENT_SERVICE.handleKeyDownOrUp(e); + e.preventDefault(); + e.stopPropagation(); + input.value = Object.keys(KEY_EVENT_SERVICE.downKeys).join(" + "); + }); + } + } + canvasToBookmark() { + var _a, _b; + const canvas = app.canvas; + if ((_a = canvas === null || canvas === void 0 ? void 0 : canvas.ds) === null || _a === void 0 ? void 0 : _a.offset) { + canvas.ds.offset[0] = -this.pos[0] + 16; + canvas.ds.offset[1] = -this.pos[1] + 40; + } + if (((_b = canvas === null || canvas === void 0 ? void 0 : canvas.ds) === null || _b === void 0 ? void 0 : _b.scale) != null) { + canvas.ds.scale = Number(this.widgets[1].value || 1); + } + canvas.setDirty(true, true); + } +} +Bookmark.type = NodeTypesString.BOOKMARK; +Bookmark.title = NodeTypesString.BOOKMARK; +Bookmark.slot_start_y = -20; +app.registerExtension({ + name: "rgthree.Bookmark", + registerCustomNodes() { + Bookmark.setUp(); + }, +}); +function isBookmark(node) { + return node.type === NodeTypesString.BOOKMARK; +} +function getExistingShortcuts() { + const graph = app.graph; + const bookmarkNodes = graph._nodes.filter(isBookmark); + const usedShortcuts = new Set(bookmarkNodes.map((n) => n.shortcutKey)); + return usedShortcuts; +} +const SHORTCUT_DEFAULTS = "1234567890abcdefghijklmnopqrstuvwxyz".split(""); +function getNextShortcut() { + var _a; + const existingShortcuts = getExistingShortcuts(); + return (_a = SHORTCUT_DEFAULTS.find((char) => !existingShortcuts.has(char))) !== null && _a !== void 0 ? _a : "1"; +} diff --git a/rgthree-comfy/web/comfyui/bypasser.js b/rgthree-comfy/web/comfyui/bypasser.js new file mode 100644 index 0000000000000000000000000000000000000000..3f97178ad0183e83568e38ffb2f030659f9bc33c --- /dev/null +++ b/rgthree-comfy/web/comfyui/bypasser.js @@ -0,0 +1,45 @@ +import { app } from "../../scripts/app.js"; +import { BaseNodeModeChanger } from "./base_node_mode_changer.js"; +import { NodeTypesString } from "./constants.js"; +const MODE_BYPASS = 4; +const MODE_ALWAYS = 0; +class BypasserNode extends BaseNodeModeChanger { + constructor(title = BypasserNode.title) { + super(title); + this.comfyClass = NodeTypesString.FAST_BYPASSER; + this.modeOn = MODE_ALWAYS; + this.modeOff = MODE_BYPASS; + this.onConstructed(); + } + async handleAction(action) { + if (action === "Bypass all") { + for (const widget of this.widgets) { + this.forceWidgetOff(widget, true); + } + } + else if (action === "Enable all") { + for (const widget of this.widgets) { + this.forceWidgetOn(widget, true); + } + } + else if (action === "Toggle all") { + for (const widget of this.widgets) { + this.forceWidgetToggle(widget, true); + } + } + } +} +BypasserNode.exposedActions = ["Bypass all", "Enable all", "Toggle all"]; +BypasserNode.type = NodeTypesString.FAST_BYPASSER; +BypasserNode.title = NodeTypesString.FAST_BYPASSER; +app.registerExtension({ + name: "rgthree.Bypasser", + registerCustomNodes() { + BypasserNode.setUp(); + }, + loadedGraphNode(node) { + if (node.type == BypasserNode.title) { + node._tempWidth = node.size[0]; + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/comfy_ui_bar.js b/rgthree-comfy/web/comfyui/comfy_ui_bar.js new file mode 100644 index 0000000000000000000000000000000000000000..1bcfa76816805ebf347b86d1d242218c21c888b3 --- /dev/null +++ b/rgthree-comfy/web/comfyui/comfy_ui_bar.js @@ -0,0 +1,100 @@ +import { app } from "../../scripts/app.js"; +import { ComfyButtonGroup } from "../../scripts/ui/components/buttonGroup.js"; +import { ComfyButton } from 
"../../scripts/ui/components/button.js"; +import { iconGear, iconStarFilled, logoRgthree } from "../../rgthree/common/media/svgs.js"; +import { createElement, empty } from "../../rgthree/common/utils_dom.js"; +import { SERVICE as BOOKMARKS_SERVICE } from "./services/bookmarks_services.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +import { ComfyPopup } from "../../scripts/ui/components/popup.js"; +import { RgthreeConfigDialog } from "./config.js"; +let rgthreeButtonGroup = null; +function addRgthreeTopBarButtons() { + var _a, _b, _c; + if (!CONFIG_SERVICE.getFeatureValue("comfy_top_bar_menu.enabled")) { + if ((_a = rgthreeButtonGroup === null || rgthreeButtonGroup === void 0 ? void 0 : rgthreeButtonGroup.element) === null || _a === void 0 ? void 0 : _a.parentElement) { + rgthreeButtonGroup.element.parentElement.removeChild(rgthreeButtonGroup.element); + } + return; + } + else if (rgthreeButtonGroup) { + (_b = app.menu) === null || _b === void 0 ? void 0 : _b.settingsGroup.element.before(rgthreeButtonGroup.element); + return; + } + const buttons = []; + const rgthreeButton = new ComfyButton({ + icon: "rgthree", + tooltip: "rgthree-comfy", + app, + enabled: true, + classList: "comfyui-button comfyui-menu-mobile-collapse primary", + }); + buttons.push(rgthreeButton); + rgthreeButton.iconElement.style.width = "1.2rem"; + rgthreeButton.iconElement.innerHTML = logoRgthree; + rgthreeButton.withPopup(new ComfyPopup({ target: rgthreeButton.element, classList: "rgthree-top-menu" }, createElement("menu", { + children: [ + createElement("li", { + child: createElement("button.rgthree-button-reset", { + html: iconGear + "Settings (rgthree-comfy)", + onclick: () => new RgthreeConfigDialog().show(), + }), + }), + createElement("li", { + child: createElement("button.rgthree-button-reset", { + html: iconStarFilled + "Star on Github", + onclick: () => window.open("https://github.com/rgthree/rgthree-comfy", "_blank"), + }), + }), + ], + })), "click"); + if (CONFIG_SERVICE.getFeatureValue("comfy_top_bar_menu.button_bookmarks.enabled")) { + const bookmarksListEl = createElement("menu"); + bookmarksListEl.appendChild(createElement("li.rgthree-message", { + child: createElement("span", { text: "No bookmarks in current workflow." }), + })); + const bookmarksButton = new ComfyButton({ + icon: "bookmark", + tooltip: "Workflow Bookmarks (rgthree-comfy)", + app, + }); + const bookmarksPopup = new ComfyPopup({ target: bookmarksButton.element, classList: "rgthree-top-menu" }, bookmarksListEl); + bookmarksPopup.addEventListener("open", () => { + const bookmarks = BOOKMARKS_SERVICE.getCurrentBookmarks(); + empty(bookmarksListEl); + if (bookmarks.length) { + for (const b of bookmarks) { + bookmarksListEl.appendChild(createElement("li", { + child: createElement("button.rgthree-button-reset", { + text: `[${b.shortcutKey}] ${b.title}`, + onclick: () => { + b.canvasToBookmark(); + }, + }), + })); + } + } + else { + bookmarksListEl.appendChild(createElement("li.rgthree-message", { + child: createElement("span", { text: "No bookmarks in current workflow." }), + })); + } + bookmarksPopup.update(); + }); + bookmarksButton.withPopup(bookmarksPopup, "hover"); + buttons.push(bookmarksButton); + } + rgthreeButtonGroup = new ComfyButtonGroup(...buttons); + (_c = app.menu) === null || _c === void 0 ? 
void 0 : _c.settingsGroup.element.before(rgthreeButtonGroup.element); +} +app.registerExtension({ + name: "rgthree.TopMenu", + async setup() { + addRgthreeTopBarButtons(); + CONFIG_SERVICE.addEventListener("config-change", ((e) => { + var _a, _b; + if ((_b = (_a = e.detail) === null || _a === void 0 ? void 0 : _a.key) === null || _b === void 0 ? void 0 : _b.includes("features.comfy_top_bar_menu")) { + addRgthreeTopBarButtons(); + } + })); + }, +}); diff --git a/rgthree-comfy/web/comfyui/config.js b/rgthree-comfy/web/comfyui/config.js new file mode 100644 index 0000000000000000000000000000000000000000..fe767194d182da7e96e3d0ab5c101dfc91b90d12 --- /dev/null +++ b/rgthree-comfy/web/comfyui/config.js @@ -0,0 +1,336 @@ +import { app } from "../../scripts/app.js"; +import { RgthreeDialog } from "../../rgthree/common/dialog.js"; +import { createElement as $el, query as $$ } from "../../rgthree/common/utils_dom.js"; +import { checkmark, logoRgthree } from "../../rgthree/common/media/svgs.js"; +import { rgthree } from "./rgthree.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +var ConfigType; +(function (ConfigType) { + ConfigType[ConfigType["UNKNOWN"] = 0] = "UNKNOWN"; + ConfigType[ConfigType["BOOLEAN"] = 1] = "BOOLEAN"; + ConfigType[ConfigType["STRING"] = 2] = "STRING"; + ConfigType[ConfigType["NUMBER"] = 3] = "NUMBER"; + ConfigType[ConfigType["ARRAY"] = 4] = "ARRAY"; +})(ConfigType || (ConfigType = {})); +const TYPE_TO_STRING = { + [ConfigType.UNKNOWN]: "unknown", + [ConfigType.BOOLEAN]: "boolean", + [ConfigType.STRING]: "string", + [ConfigType.NUMBER]: "number", + [ConfigType.ARRAY]: "array", +}; +const CONFIGURABLE = { + features: [ + { + key: "features.patch_recursive_execution", + type: ConfigType.BOOLEAN, + label: "Optimize ComfyUI's Execution", + description: "Patches ComfyUI's backend execution making complex workflows 1000's of times faster." + + "
⚠️ Disable if execution seems broken due to forward ComfyUI changes.", + }, + { + key: "features.progress_bar.enabled", + type: ConfigType.BOOLEAN, + label: "Prompt Progress Bar", + description: `Shows a minimal progress bar for nodes and steps at the top of the app.`, + subconfig: [ + { + key: "features.progress_bar.height", + type: ConfigType.NUMBER, + label: "Height of the bar", + }, + { + key: "features.progress_bar.position", + type: ConfigType.STRING, + label: "Position at top or bottom of window", + options: ["top", "bottom"], + }, + ], + }, + { + key: "features.import_individual_nodes.enabled", + type: ConfigType.BOOLEAN, + label: "Import Individual Nodes Widgets", + description: "Dragging & Dropping a similar image/JSON workflow onto (most) current workflow nodes" + + "will allow you to import that workflow's node's widgets when it has the same " + + "id and type. This is useful when you have several images and you'd like to import just " + + "one part of a previous iteration, like a seed, or prompt.", + }, + ], + menus: [ + { + key: "features.comfy_top_bar_menu.enabled", + type: ConfigType.BOOLEAN, + label: "Enable Top Bar Menu", + description: "Have quick access from ComfyUI's new top bar to rgthree-comfy bookmarks, settings " + + "(and more to come).", + }, + { + key: "features.menu_queue_selected_nodes", + type: ConfigType.BOOLEAN, + label: "Show 'Queue Selected Output Nodes'", + description: "Will show a menu item in the right-click context menus to queue (only) the selected " + + "output nodes.", + }, + { + key: "features.menu_auto_nest.subdirs", + type: ConfigType.BOOLEAN, + label: "Auto Nest Subdirectories in Menus", + description: "When a large, flat list of values contain sub-directories, auto nest them. (Like, for " + + "a large list of checkpoints).", + subconfig: [ + { + key: "features.menu_auto_nest.threshold", + type: ConfigType.NUMBER, + label: "Number of items needed to trigger nesting.", + }, + ], + }, + { + key: "features.menu_bookmarks.enabled", + type: ConfigType.BOOLEAN, + label: "Show Bookmarks in context menu", + description: "Will list bookmarks in the rgthree-comfy right-click context menu.", + }, + ], + groups: [ + { + key: "features.group_header_fast_toggle.enabled", + type: ConfigType.BOOLEAN, + label: "Show fast toggles in Group Headers", + description: "Show quick toggles in Groups' Headers to quickly mute and/or bypass.", + subconfig: [ + { + key: "features.group_header_fast_toggle.toggles", + type: ConfigType.ARRAY, + label: "Which toggles to show.", + options: [ + { value: ["mute"], label: "mute only" }, + { value: ["bypass"], label: "bypass only" }, + { value: ["mute", "bypass"], label: "mute and bypass" }, + ], + }, + { + key: "features.group_header_fast_toggle.show", + type: ConfigType.STRING, + label: "When to show them.", + options: [ + { value: "hover", label: "on hover" }, + { value: "always", label: "always" }, + ], + }, + ], + }, + ], + advanced: [ + { + key: "features.show_alerts_for_corrupt_workflows", + type: ConfigType.BOOLEAN, + label: "Detect Corrupt Workflows", + description: "Will show a message at the top of the screen when loading a workflow that has " + + "corrupt linking data.", + }, + { + key: "log_level", + type: ConfigType.STRING, + label: "Log level for browser dev console.", + description: "Further down the list, the more verbose logs to the console will be. 
For instance, " + + "selecting 'IMPORTANT' means only important message will be logged to the browser " + + "console, while selecting 'WARN' will log all messages at or higher than WARN, including " + + "'ERROR' and 'IMPORTANT' etc.", + options: ["IMPORTANT", "ERROR", "WARN", "INFO", "DEBUG", "DEV"], + isDevOnly: true, + onSave: function (value) { + rgthree.setLogLevel(value); + }, + }, + { + key: "features.invoke_extensions_async.node_created", + type: ConfigType.BOOLEAN, + label: "Allow other extensions to call nodeCreated on rgthree-nodes.", + isDevOnly: true, + description: "Do not disable unless you are having trouble (and then file an issue at rgthree-comfy)." + + "Prior to Apr 2024 it was not possible for other extensions to invoke their nodeCreated " + + "event on some rgthree-comfy nodes. Now it's possible and this option is only here in " + + "for easy if something is wrong.", + }, + ], +}; +function fieldrow(item) { + var _a; + const initialValue = CONFIG_SERVICE.getConfigValue(item.key); + const container = $el(`div.fieldrow.-type-${TYPE_TO_STRING[item.type]}`, { + dataset: { + name: item.key, + initial: initialValue, + type: item.type, + }, + }); + $el(`label[for="${item.key}"]`, { + children: [ + $el(`span[text="${item.label}"]`), + item.description ? $el("small", { html: item.description }) : null, + ], + parent: container, + }); + let input; + if ((_a = item.options) === null || _a === void 0 ? void 0 : _a.length) { + input = $el(`select[id="${item.key}"]`, { + parent: container, + children: item.options.map((o) => { + const label = o.label || String(o); + const value = o.value || o; + const valueSerialized = JSON.stringify({ value: value }); + return $el(`option[value="${valueSerialized}"]`, { + text: label, + selected: valueSerialized === JSON.stringify({ value: initialValue }), + }); + }), + }); + } + else if (item.type === ConfigType.BOOLEAN) { + container.classList.toggle("-checked", !!initialValue); + input = $el(`input[type="checkbox"][id="${item.key}"]`, { + parent: container, + checked: initialValue, + }); + } + else { + input = $el(`input[id="${item.key}"]`, { + parent: container, + value: initialValue, + }); + } + $el("div.fieldrow-value", { children: [input], parent: container }); + return container; +} +export class RgthreeConfigDialog extends RgthreeDialog { + constructor() { + const content = $el("div"); + content.appendChild(RgthreeConfigDialog.buildFieldset(CONFIGURABLE["features"], "Features")); + content.appendChild(RgthreeConfigDialog.buildFieldset(CONFIGURABLE["menus"], "Menus")); + content.appendChild(RgthreeConfigDialog.buildFieldset(CONFIGURABLE["groups"], "Groups")); + content.appendChild(RgthreeConfigDialog.buildFieldset(CONFIGURABLE["advanced"], "Advanced")); + content.addEventListener("input", (e) => { + const changed = this.getChangedFormData(); + $$(".save-button", this.element)[0].disabled = + !Object.keys(changed).length; + }); + content.addEventListener("change", (e) => { + const changed = this.getChangedFormData(); + $$(".save-button", this.element)[0].disabled = + !Object.keys(changed).length; + }); + const dialogOptions = { + class: "-iconed -settings", + title: logoRgthree + `

Settings - rgthree-comfy

`, + content, + onBeforeClose: () => { + const changed = this.getChangedFormData(); + if (Object.keys(changed).length) { + return confirm("Looks like there are unsaved changes. Are you sure you want close?"); + } + return true; + }, + buttons: [ + { + label: "Save", + disabled: true, + className: "rgthree-button save-button -blue", + callback: async (e) => { + var _a, _b; + const changed = this.getChangedFormData(); + if (!Object.keys(changed).length) { + this.close(); + return; + } + const success = await CONFIG_SERVICE.setConfigValues(changed); + if (success) { + for (const key of Object.keys(changed)) { + (_b = (_a = Object.values(CONFIGURABLE) + .flat() + .find((f) => f.key === key)) === null || _a === void 0 ? void 0 : _a.onSave) === null || _b === void 0 ? void 0 : _b.call(_a, changed[key]); + } + this.close(); + rgthree.showMessage({ + id: "config-success", + message: `${checkmark} Successfully saved rgthree-comfy settings!`, + timeout: 4000, + }); + $$(".save-button", this.element)[0].disabled = true; + } + else { + alert("There was an error saving rgthree-comfy configuration."); + } + }, + }, + ], + }; + super(dialogOptions); + } + static buildFieldset(datas, label) { + const fieldset = $el(`fieldset`, { children: [$el(`legend[text="${label}"]`)] }); + for (const data of datas) { + if (data.isDevOnly && !rgthree.isDevMode()) { + continue; + } + const container = $el("div.formrow"); + container.appendChild(fieldrow(data)); + if (data.subconfig) { + for (const subfeature of data.subconfig) { + container.appendChild(fieldrow(subfeature)); + } + } + fieldset.appendChild(container); + } + return fieldset; + } + getChangedFormData() { + return $$("[data-name]", this.contentElement).reduce((acc, el) => { + const name = el.dataset["name"]; + const type = el.dataset["type"]; + const initialValue = CONFIG_SERVICE.getConfigValue(name); + let currentValueEl = $$("input, textarea, select", el)[0]; + let currentValue = null; + if (type === String(ConfigType.BOOLEAN)) { + currentValue = currentValueEl.checked; + el.classList.toggle("-checked", currentValue); + } + else { + currentValue = currentValueEl === null || currentValueEl === void 0 ? void 0 : currentValueEl.value; + if (currentValueEl.nodeName === "SELECT") { + currentValue = JSON.parse(currentValue).value; + } + else if (type === String(ConfigType.NUMBER)) { + currentValue = Number(currentValue) || initialValue; + } + } + if (JSON.stringify(currentValue) !== JSON.stringify(initialValue)) { + acc[name] = currentValue; + } + return acc; + }, {}); + } +} +app.ui.settings.addSetting({ + id: "rgthree.config", + name: "Open rgthree-comfy config", + type: () => { + return $el("tr.rgthree-comfyui-settings-row", { + children: [ + $el("td", { + child: `
${logoRgthree} [rgthree-comfy] configuration / settings
`, + }), + $el("td", { + child: $el('button.rgthree-button.-blue[text="rgthree-comfy settings"]', { + events: { + click: (e) => { + new RgthreeConfigDialog().show(); + }, + }, + }), + }), + ], + }); + }, +}); diff --git a/rgthree-comfy/web/comfyui/constants.js b/rgthree-comfy/web/comfyui/constants.js new file mode 100644 index 0000000000000000000000000000000000000000..cc506af6eb75ca96ee79ccd8d089f6aac9c82874 --- /dev/null +++ b/rgthree-comfy/web/comfyui/constants.js @@ -0,0 +1,53 @@ +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +export function addRgthree(str) { + return str + " (rgthree)"; +} +export function stripRgthree(str) { + return str.replace(/\s*\(rgthree\)$/, ""); +} +export const NodeTypesString = { + ANY_SWITCH: addRgthree("Any Switch"), + CONTEXT: addRgthree("Context"), + CONTEXT_BIG: addRgthree("Context Big"), + CONTEXT_SWITCH: addRgthree("Context Switch"), + CONTEXT_SWITCH_BIG: addRgthree("Context Switch Big"), + CONTEXT_MERGE: addRgthree("Context Merge"), + CONTEXT_MERGE_BIG: addRgthree("Context Merge Big"), + DYNAMIC_CONTEXT: addRgthree("Dynamic Context"), + DYNAMIC_CONTEXT_SWITCH: addRgthree("Dynamic Context Switch"), + DISPLAY_ANY: addRgthree("Display Any"), + NODE_MODE_RELAY: addRgthree("Mute / Bypass Relay"), + NODE_MODE_REPEATER: addRgthree("Mute / Bypass Repeater"), + FAST_MUTER: addRgthree("Fast Muter"), + FAST_BYPASSER: addRgthree("Fast Bypasser"), + FAST_GROUPS_MUTER: addRgthree("Fast Groups Muter"), + FAST_GROUPS_BYPASSER: addRgthree("Fast Groups Bypasser"), + FAST_ACTIONS_BUTTON: addRgthree("Fast Actions Button"), + LABEL: addRgthree("Label"), + POWER_PROMPT: addRgthree("Power Prompt"), + POWER_PROMPT_SIMPLE: addRgthree("Power Prompt - Simple"), + SDXL_EMPTY_LATENT_IMAGE: addRgthree("SDXL Empty Latent Image"), + SDXL_POWER_PROMPT_POSITIVE: addRgthree("SDXL Power Prompt - Positive"), + SDXL_POWER_PROMPT_NEGATIVE: addRgthree("SDXL Power Prompt - Simple / Negative"), + POWER_LORA_LOADER: addRgthree("Power Lora Loader"), + KSAMPLER_CONFIG: addRgthree("KSampler Config"), + NODE_COLLECTOR: addRgthree("Node Collector"), + REROUTE: addRgthree("Reroute"), + RANDOM_UNMUTER: addRgthree("Random Unmuter"), + SEED: addRgthree("Seed"), + BOOKMARK: addRgthree("Bookmark"), + IMAGE_COMPARER: addRgthree("Image Comparer"), + IMAGE_INSET_CROP: addRgthree("Image Inset Crop"), +}; +export function getNodeTypeStrings() { + return Object.values(NodeTypesString) + .map((i) => stripRgthree(i)) + .filter((i) => { + if (i.startsWith("Dynamic Context") && + !CONFIG_SERVICE.getConfigValue("unreleased.dynamic_context.enabled")) { + return false; + } + return true; + }) + .sort(); +} diff --git a/rgthree-comfy/web/comfyui/context.js b/rgthree-comfy/web/comfyui/context.js new file mode 100644 index 0000000000000000000000000000000000000000..3df222e59b16ff5ecc75923b6d8081d2b11a1bf9 --- /dev/null +++ b/rgthree-comfy/web/comfyui/context.js @@ -0,0 +1,323 @@ +import { app } from "../../scripts/app.js"; +import { IoDirection, addConnectionLayoutSupport, addMenuItem, matchLocalSlotsToServer, replaceNode, } from "./utils.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { SERVICE as KEY_EVENT_SERVICE } from "./services/key_events_services.js"; +import { debounce, wait } from "../../rgthree/common/shared_utils.js"; +import { removeUnusedInputsFromEnd } from "./utils_inputs_outputs.js"; +import { NodeTypesString } from "./constants.js"; +function findMatchingIndexByTypeOrName(otherNode, otherSlot, ctxSlots) { + const otherNodeType = (otherNode.type || 
"").toUpperCase(); + const otherNodeName = (otherNode.title || "").toUpperCase(); + let otherSlotType = otherSlot.type; + if (Array.isArray(otherSlotType) || otherSlotType.includes(",")) { + otherSlotType = "COMBO"; + } + const otherSlotName = otherSlot.name.toUpperCase().replace("OPT_", "").replace("_NAME", ""); + let ctxSlotIndex = -1; + if (["CONDITIONING", "INT", "STRING", "FLOAT", "COMBO"].includes(otherSlotType)) { + ctxSlotIndex = ctxSlots.findIndex((ctxSlot) => { + const ctxSlotName = ctxSlot.name.toUpperCase().replace("OPT_", "").replace("_NAME", ""); + let ctxSlotType = ctxSlot.type; + if (Array.isArray(ctxSlotType) || ctxSlotType.includes(",")) { + ctxSlotType = "COMBO"; + } + if (ctxSlotType !== otherSlotType) { + return false; + } + if (ctxSlotName === otherSlotName || + (ctxSlotName === "SEED" && otherSlotName.includes("SEED")) || + (ctxSlotName === "STEP_REFINER" && otherSlotName.includes("AT_STEP")) || + (ctxSlotName === "STEP_REFINER" && otherSlotName.includes("REFINER_STEP"))) { + return true; + } + if ((otherNodeType.includes("POSITIVE") || otherNodeName.includes("POSITIVE")) && + ((ctxSlotName === "POSITIVE" && otherSlotType === "CONDITIONING") || + (ctxSlotName === "TEXT_POS_G" && otherSlotName.includes("TEXT_G")) || + (ctxSlotName === "TEXT_POS_L" && otherSlotName.includes("TEXT_L")))) { + return true; + } + if ((otherNodeType.includes("NEGATIVE") || otherNodeName.includes("NEGATIVE")) && + ((ctxSlotName === "NEGATIVE" && otherSlotType === "CONDITIONING") || + (ctxSlotName === "TEXT_NEG_G" && otherSlotName.includes("TEXT_G")) || + (ctxSlotName === "TEXT_NEG_L" && otherSlotName.includes("TEXT_L")))) { + return true; + } + return false; + }); + } + else { + ctxSlotIndex = ctxSlots.map((s) => s.type).indexOf(otherSlotType); + } + return ctxSlotIndex; +} +export class BaseContextNode extends RgthreeBaseServerNode { + constructor(title) { + super(title); + this.___collapsed_width = 0; + } + get _collapsed_width() { + return this.___collapsed_width; + } + set _collapsed_width(width) { + const canvas = app.canvas; + const ctx = canvas.canvas.getContext("2d"); + const oldFont = ctx.font; + ctx.font = canvas.title_text_font; + let title = this.title.trim(); + this.___collapsed_width = 30 + (title ? 
10 + ctx.measureText(title).width : 0); + ctx.font = oldFont; + } + connectByType(slot, sourceNode, sourceSlotType, optsIn) { + let canConnect = super.connectByType && + super.connectByType.call(this, slot, sourceNode, sourceSlotType, optsIn); + if (!super.connectByType) { + canConnect = LGraphNode.prototype.connectByType.call(this, slot, sourceNode, sourceSlotType, optsIn); + } + if (!canConnect && slot === 0) { + const ctrlKey = KEY_EVENT_SERVICE.ctrlKey; + for (const [index, input] of (sourceNode.inputs || []).entries()) { + if (input.link && !ctrlKey) { + continue; + } + const thisOutputSlot = findMatchingIndexByTypeOrName(sourceNode, input, this.outputs); + if (thisOutputSlot > -1) { + this.connect(thisOutputSlot, sourceNode, index); + } + } + } + return null; + } + connectByTypeOutput(slot, sourceNode, sourceSlotType, optsIn) { + var _a; + let canConnect = super.connectByTypeOutput && + super.connectByTypeOutput.call(this, slot, sourceNode, sourceSlotType, optsIn); + if (!super.connectByType) { + canConnect = LGraphNode.prototype.connectByTypeOutput.call(this, slot, sourceNode, sourceSlotType, optsIn); + } + if (!canConnect && slot === 0) { + const ctrlKey = KEY_EVENT_SERVICE.ctrlKey; + for (const [index, output] of (sourceNode.outputs || []).entries()) { + if (((_a = output.links) === null || _a === void 0 ? void 0 : _a.length) && !ctrlKey) { + continue; + } + const thisInputSlot = findMatchingIndexByTypeOrName(sourceNode, output, this.inputs); + if (thisInputSlot > -1) { + sourceNode.connect(index, this, thisInputSlot); + } + } + } + return null; + } + static setUp(comfyClass, nodeData, ctxClass) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, ctxClass); + wait(500).then(() => { + LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"] = + LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"] || []; + LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"].push(comfyClass.comfyClass); + }); + } + static onRegisteredForOverride(comfyClass, ctxClass) { + addConnectionLayoutSupport(ctxClass, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + setTimeout(() => { + ctxClass.category = comfyClass.category; + }); + } +} +class ContextNode extends BaseContextNode { + constructor(title = ContextNode.title) { + super(title); + } + static setUp(comfyClass, nodeData) { + BaseContextNode.setUp(comfyClass, nodeData, ContextNode); + } + static onRegisteredForOverride(comfyClass, ctxClass) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextNode, app, { + name: "Convert To Context Big", + callback: (node) => { + replaceNode(node, ContextBigNode.type); + }, + }); + } +} +ContextNode.title = NodeTypesString.CONTEXT; +ContextNode.type = NodeTypesString.CONTEXT; +ContextNode.comfyClass = NodeTypesString.CONTEXT; +class ContextBigNode extends BaseContextNode { + constructor(title = ContextBigNode.title) { + super(title); + } + static setUp(comfyClass, nodeData) { + BaseContextNode.setUp(comfyClass, nodeData, ContextBigNode); + } + static onRegisteredForOverride(comfyClass, ctxClass) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextBigNode, app, { + name: "Convert To Context (Original)", + callback: (node) => { + replaceNode(node, ContextNode.type); + }, + }); + } +} +ContextBigNode.title = NodeTypesString.CONTEXT_BIG; +ContextBigNode.type = NodeTypesString.CONTEXT_BIG; +ContextBigNode.comfyClass = NodeTypesString.CONTEXT_BIG; +class BaseContextMultiCtxInputNode extends BaseContextNode { + constructor(title) { + 
super(title); + this.stabilizeBound = this.stabilize.bind(this); + this.addContextInput(5); + } + addContextInput(num = 1) { + for (let i = 0; i < num; i++) { + this.addInput(`ctx_${String(this.inputs.length + 1).padStart(2, "0")}`, "RGTHREE_CONTEXT"); + } + } + onConnectionsChange(type, slotIndex, isConnected, link, ioSlot) { + var _a; + (_a = super.onConnectionsChange) === null || _a === void 0 ? void 0 : _a.apply(this, [...arguments]); + if (type === LiteGraph.INPUT) { + this.scheduleStabilize(); + } + } + scheduleStabilize(ms = 64) { + return debounce(this.stabilizeBound, 64); + } + stabilize() { + removeUnusedInputsFromEnd(this, 4); + this.addContextInput(); + } +} +class ContextSwitchNode extends BaseContextMultiCtxInputNode { + constructor(title = ContextSwitchNode.title) { + super(title); + } + static setUp(comfyClass, nodeData) { + BaseContextNode.setUp(comfyClass, nodeData, ContextSwitchNode); + } + static onRegisteredForOverride(comfyClass, ctxClass) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextSwitchNode, app, { + name: "Convert To Context Switch Big", + callback: (node) => { + replaceNode(node, ContextSwitchBigNode.type); + }, + }); + } +} +ContextSwitchNode.title = NodeTypesString.CONTEXT_SWITCH; +ContextSwitchNode.type = NodeTypesString.CONTEXT_SWITCH; +ContextSwitchNode.comfyClass = NodeTypesString.CONTEXT_SWITCH; +class ContextSwitchBigNode extends BaseContextMultiCtxInputNode { + constructor(title = ContextSwitchBigNode.title) { + super(title); + } + static setUp(comfyClass, nodeData) { + BaseContextNode.setUp(comfyClass, nodeData, ContextSwitchBigNode); + } + static onRegisteredForOverride(comfyClass, ctxClass) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextSwitchBigNode, app, { + name: "Convert To Context Switch", + callback: (node) => { + replaceNode(node, ContextSwitchNode.type); + }, + }); + } +} +ContextSwitchBigNode.title = NodeTypesString.CONTEXT_SWITCH_BIG; +ContextSwitchBigNode.type = NodeTypesString.CONTEXT_SWITCH_BIG; +ContextSwitchBigNode.comfyClass = NodeTypesString.CONTEXT_SWITCH_BIG; +class ContextMergeNode extends BaseContextMultiCtxInputNode { + constructor(title = ContextMergeNode.title) { + super(title); + } + static setUp(comfyClass, nodeData) { + BaseContextNode.setUp(comfyClass, nodeData, ContextMergeNode); + } + static onRegisteredForOverride(comfyClass, ctxClass) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextMergeNode, app, { + name: "Convert To Context Merge Big", + callback: (node) => { + replaceNode(node, ContextMergeBigNode.type); + }, + }); + } +} +ContextMergeNode.title = NodeTypesString.CONTEXT_MERGE; +ContextMergeNode.type = NodeTypesString.CONTEXT_MERGE; +ContextMergeNode.comfyClass = NodeTypesString.CONTEXT_MERGE; +class ContextMergeBigNode extends BaseContextMultiCtxInputNode { + constructor(title = ContextMergeBigNode.title) { + super(title); + } + static setUp(comfyClass, nodeData) { + BaseContextNode.setUp(comfyClass, nodeData, ContextMergeBigNode); + } + static onRegisteredForOverride(comfyClass, ctxClass) { + BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass); + addMenuItem(ContextMergeBigNode, app, { + name: "Convert To Context Switch", + callback: (node) => { + replaceNode(node, ContextMergeNode.type); + }, + }); + } +} +ContextMergeBigNode.title = NodeTypesString.CONTEXT_MERGE_BIG; +ContextMergeBigNode.type = NodeTypesString.CONTEXT_MERGE_BIG; +ContextMergeBigNode.comfyClass = 
NodeTypesString.CONTEXT_MERGE_BIG; +const contextNodes = [ + ContextNode, + ContextBigNode, + ContextSwitchNode, + ContextSwitchBigNode, + ContextMergeNode, + ContextMergeBigNode, +]; +const contextTypeToServerDef = {}; +function fixBadConfigs(node) { + const wrongName = node.outputs.find((o, i) => o.name === "CLIP_HEIGTH"); + if (wrongName) { + wrongName.name = "CLIP_HEIGHT"; + } +} +app.registerExtension({ + name: "rgthree.Context", + async beforeRegisterNodeDef(nodeType, nodeData) { + for (const ctxClass of contextNodes) { + if (nodeData.name === ctxClass.type) { + contextTypeToServerDef[ctxClass.type] = nodeData; + ctxClass.setUp(nodeType, nodeData); + break; + } + } + }, + async nodeCreated(node) { + const type = node.type || node.constructor.type; + const serverDef = type && contextTypeToServerDef[type]; + if (serverDef) { + fixBadConfigs(node); + matchLocalSlotsToServer(node, IoDirection.OUTPUT, serverDef); + if (!type.includes("Switch") && !type.includes("Merge")) { + matchLocalSlotsToServer(node, IoDirection.INPUT, serverDef); + } + } + }, + async loadedGraphNode(node) { + const type = node.type || node.constructor.type; + const serverDef = type && contextTypeToServerDef[type]; + if (serverDef) { + fixBadConfigs(node); + matchLocalSlotsToServer(node, IoDirection.OUTPUT, serverDef); + if (!type.includes("Switch") && !type.includes("Merge")) { + matchLocalSlotsToServer(node, IoDirection.INPUT, serverDef); + } + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/dialog_info.js b/rgthree-comfy/web/comfyui/dialog_info.js new file mode 100644 index 0000000000000000000000000000000000000000..4d621a6ae1a239a2d40266e39b829b260cabd976 --- /dev/null +++ b/rgthree-comfy/web/comfyui/dialog_info.js @@ -0,0 +1,269 @@ +import { RgthreeDialog } from "../../rgthree/common/dialog.js"; +import { createElement as $el, empty, appendChildren, getClosestOrSelf, queryOne, query, setAttributes, } from "../../rgthree/common/utils_dom.js"; +import { logoCivitai, link, pencilColored, diskColored, dotdotdot, } from "../../rgthree/common/media/svgs.js"; +import { SERVICE as MODEL_INFO_SERVICE } from "../../rgthree/common/model_info_service.js"; +import { rgthree } from "./rgthree.js"; +import { MenuButton } from "../../rgthree/common/menu.js"; +import { generateId, injectCss } from "../../rgthree/common/shared_utils.js"; +export class RgthreeInfoDialog extends RgthreeDialog { + constructor(file) { + const dialogOptions = { + class: "rgthree-info-dialog", + title: `

Loading...

`, + content: "
Loading...
", + onBeforeClose: () => { + return true; + }, + }; + super(dialogOptions); + this.modifiedModelData = false; + this.modelInfo = null; + this.init(file); + } + async init(file) { + var _a, _b; + const cssPromise = injectCss("rgthree/common/css/dialog_model_info.css"); + this.modelInfo = await MODEL_INFO_SERVICE.getLora(file, false, false); + await cssPromise; + this.setContent(this.getInfoContent()); + this.setTitle(((_a = this.modelInfo) === null || _a === void 0 ? void 0 : _a["name"]) || ((_b = this.modelInfo) === null || _b === void 0 ? void 0 : _b["file"]) || "Unknown"); + this.attachEvents(); + } + getCloseEventDetail() { + const detail = { + dirty: this.modifiedModelData, + }; + return { detail }; + } + attachEvents() { + this.contentElement.addEventListener("click", async (e) => { + const target = getClosestOrSelf(e.target, "[data-action]"); + const action = target === null || target === void 0 ? void 0 : target.getAttribute("data-action"); + if (!target || !action) { + return; + } + await this.handleEventAction(action, target, e); + }); + } + async handleEventAction(action, target, e) { + var _a, _b; + const info = this.modelInfo; + if (!(info === null || info === void 0 ? void 0 : info.file)) { + return; + } + if (action === "fetch-civitai") { + this.modelInfo = await MODEL_INFO_SERVICE.refreshLora(info.file); + this.setContent(this.getInfoContent()); + this.setTitle(((_a = this.modelInfo) === null || _a === void 0 ? void 0 : _a["name"]) || ((_b = this.modelInfo) === null || _b === void 0 ? void 0 : _b["file"]) || "Unknown"); + } + else if (action === "copy-trained-words") { + const selected = query(".-rgthree-is-selected", target.closest("tr")); + const text = selected.map((el) => el.getAttribute("data-word")).join(", "); + await navigator.clipboard.writeText(text); + rgthree.showMessage({ + id: "copy-trained-words-" + generateId(4), + type: "success", + message: `Successfully copied ${selected.length} key word${selected.length === 1 ? "" : "s"}.`, + timeout: 4000, + }); + } + else if (action === "toggle-trained-word") { + target === null || target === void 0 ? void 0 : target.classList.toggle("-rgthree-is-selected"); + const tr = target.closest("tr"); + if (tr) { + const span = queryOne("td:first-child > *", tr); + let small = queryOne("small", span); + if (!small) { + small = $el("small", { parent: span }); + } + const num = query(".-rgthree-is-selected", tr).length; + small.innerHTML = num + ? `${num} selected | Copy` + : ""; + } + } + else if (action === "edit-row") { + const tr = target.closest("tr"); + const td = queryOne("td:nth-child(2)", tr); + const input = td.querySelector("input,textarea"); + if (!input) { + const fieldName = tr.dataset["fieldName"]; + tr.classList.add("-rgthree-editing"); + const isTextarea = fieldName === "userNote"; + const input = $el(`${isTextarea ? 
"textarea" : 'input[type="text"]'}`, { + value: td.textContent, + }); + input.addEventListener("keydown", (e) => { + if (!isTextarea && e.key === "Enter") { + const modified = saveEditableRow(info, tr, true); + this.modifiedModelData = this.modifiedModelData || modified; + e.stopPropagation(); + e.preventDefault(); + } + else if (e.key === "Escape") { + const modified = saveEditableRow(info, tr, false); + this.modifiedModelData = this.modifiedModelData || modified; + e.stopPropagation(); + e.preventDefault(); + } + }); + appendChildren(empty(td), [input]); + input.focus(); + } + else if (target.nodeName.toLowerCase() === "button") { + const modified = saveEditableRow(info, tr, true); + this.modifiedModelData = this.modifiedModelData || modified; + } + e === null || e === void 0 ? void 0 : e.preventDefault(); + e === null || e === void 0 ? void 0 : e.stopPropagation(); + } + } + getInfoContent() { + var _a, _b, _c, _d, _e, _f, _g, _h, _j, _k, _l, _m, _o, _p, _q, _r, _s, _t, _u, _v, _w, _x, _y; + const info = this.modelInfo || {}; + const civitaiLink = (_a = info.links) === null || _a === void 0 ? void 0 : _a.find((i) => i.includes("civitai.com/models")); + const html = ` +
+            <ul>
+                <li>${info.type || ""}</li>
+                <li>${info.baseModel || ""}</li>
+                <li stub="menu">${""}</li>
+            </ul>
+            <table>
+                ${infoTableRow("File", info.file || "")}
+                ${infoTableRow("Hash (sha256)", info.sha256 || "")}
+                ${civitaiLink
+                    ? infoTableRow("Civitai", `<a href="${civitaiLink}" target="_blank">${logoCivitai}View on Civitai</a>`)
+                    : ((_c = (_b = info.raw) === null || _b === void 0 ? void 0 : _b.civitai) === null || _c === void 0 ? void 0 : _c.error) === "Model not found"
+                        ? infoTableRow("Civitai", 'Model not found')
+                        : ((_e = (_d = info.raw) === null || _d === void 0 ? void 0 : _d.civitai) === null || _e === void 0 ? void 0 : _e.error)
+                            ? infoTableRow("Civitai", (_g = (_f = info.raw) === null || _f === void 0 ? void 0 : _f.civitai) === null || _g === void 0 ? void 0 : _g.error)
+                            : !((_h = info.raw) === null || _h === void 0 ? void 0 : _h.civitai)
+                                ? infoTableRow("Civitai", `<button data-action="fetch-civitai">Fetch from Civitai</button>`)
+                                : ""}
+                ${infoTableRow("Name", info.name || ((_k = (_j = info.raw) === null || _j === void 0 ? void 0 : _j.metadata) === null || _k === void 0 ? void 0 : _k.ss_output_name) || "", "The name for display.", "name")}
+                ${!info.baseModel && !info.baseModelFile
+                    ? ""
+                    : infoTableRow("Base Model", (info.baseModel || "") + (info.baseModelFile ? ` (${info.baseModelFile})` : ""))}
+                ${!((_l = info.trainedWords) === null || _l === void 0 ? void 0 : _l.length)
+                    ? ""
+                    : infoTableRow("Trained Words", (_m = getTrainedWordsMarkup(info.trainedWords)) !== null && _m !== void 0 ? _m : "", "Trained words from the metadata and/or civitai. Click to select for copy.")}
+                ${!((_p = (_o = info.raw) === null || _o === void 0 ? void 0 : _o.metadata) === null || _p === void 0 ? void 0 : _p.ss_clip_skip) || ((_r = (_q = info.raw) === null || _q === void 0 ? void 0 : _q.metadata) === null || _r === void 0 ? void 0 : _r.ss_clip_skip) == "None"
+                    ? ""
+                    : infoTableRow("Clip Skip", (_t = (_s = info.raw) === null || _s === void 0 ? void 0 : _s.metadata) === null || _t === void 0 ? void 0 : _t.ss_clip_skip)}
+                ${infoTableRow("Strength Min", (_u = info.strengthMin) !== null && _u !== void 0 ? _u : "", "The recommended minimum strength. In the Power Lora Loader node, strength will signal when it is below this threshold.", "strengthMin")}
+                ${infoTableRow("Strength Max", (_v = info.strengthMax) !== null && _v !== void 0 ? _v : "", "The recommended maximum strength. In the Power Lora Loader node, strength will signal when it is above this threshold.", "strengthMax")}
+                ${""}
+                ${infoTableRow("Additional Notes", (_w = info.userNote) !== null && _w !== void 0 ? _w : "", "Additional notes you'd like to keep and reference in the info dialog.", "userNote")}
+            </table>
+ +
    ${(_y = (_x = info.images) === null || _x === void 0 ? void 0 : _x.map((img) => ` +
  • +
    + +
    ${imgInfoField("", img.civitaiUrl + ? `civitai${link}` + : undefined)}${imgInfoField("seed", img.seed)}${imgInfoField("steps", img.steps)}${imgInfoField("cfg", img.cfg)}${imgInfoField("sampler", img.sampler)}${imgInfoField("model", img.model)}${imgInfoField("positive", img.positive)}${imgInfoField("negative", img.negative)}
    +
    +
  • `).join("")) !== null && _y !== void 0 ? _y : ""}
+ `; + const div = $el("div", { html }); + if (rgthree.isDevMode()) { + setAttributes(queryOne('[stub="menu"]', div), { + children: [ + new MenuButton({ + icon: dotdotdot, + options: [ + { label: "More Actions", type: "title" }, + { + label: "Open API JSON", + callback: async (e) => { + var _a; + if ((_a = this.modelInfo) === null || _a === void 0 ? void 0 : _a.file) { + window.open(`rgthree/api/loras/info?file=${encodeURIComponent(this.modelInfo.file)}`); + } + }, + }, + { + label: "Clear all local info", + callback: async (e) => { + var _a, _b, _c; + if ((_a = this.modelInfo) === null || _a === void 0 ? void 0 : _a.file) { + this.modelInfo = await MODEL_INFO_SERVICE.clearLoraFetchedData(this.modelInfo.file); + this.setContent(this.getInfoContent()); + this.setTitle(((_b = this.modelInfo) === null || _b === void 0 ? void 0 : _b["name"]) || ((_c = this.modelInfo) === null || _c === void 0 ? void 0 : _c["file"]) || "Unknown"); + } + }, + }, + ], + }), + ], + }); + } + return div; + } +} +function infoTableRow(name, value, help = "", editableFieldName = "") { + return ` + + ${name} ${help ? `` : ""} + ${String(value).startsWith("<") ? value : `${value}`} + ${editableFieldName + ? `` + : ""} + `; +} +function getTrainedWordsMarkup(words) { + let markup = `
+        <ul>`;
+    // Each trained word is rendered as a clickable chip; the dialog's click handler uses the
+    // data-word and data-action attributes to toggle selection and copy the selected words.
+    for (const wordData of words || []) {
+        markup += `
+            <li data-word="${wordData.word}" data-action="toggle-trained-word">
+                ${wordData.word}
+                ${wordData.civitai ? logoCivitai : ""}
+                ${wordData.count != null ? `<small>${wordData.count}</small>` : ""}
+            </li>`;
+    }
+    markup += `</ul>
`; + return markup; +} +function saveEditableRow(info, tr, saving = true) { + var _a; + const fieldName = tr.dataset["fieldName"]; + const input = queryOne("input,textarea", tr); + let newValue = (_a = info[fieldName]) !== null && _a !== void 0 ? _a : ""; + let modified = false; + if (saving) { + newValue = input.value; + if (fieldName.startsWith("strength")) { + if (Number.isNaN(Number(newValue))) { + alert(`You must enter a number into the ${fieldName} field.`); + return false; + } + newValue = (Math.round(Number(newValue) * 100) / 100).toFixed(2); + } + MODEL_INFO_SERVICE.saveLoraPartial(info.file, { [fieldName]: newValue }); + modified = true; + } + tr.classList.remove("-rgthree-editing"); + const td = queryOne("td:nth-child(2)", tr); + appendChildren(empty(td), [$el("span", { text: newValue })]); + return modified; +} +function imgInfoField(label, value) { + return value != null ? `${label ? `` : ""}${value}` : ""; +} diff --git a/rgthree-comfy/web/comfyui/display_any.js b/rgthree-comfy/web/comfyui/display_any.js new file mode 100644 index 0000000000000000000000000000000000000000..cff00f9258cced5da324de9e18aa91de7d857038 --- /dev/null +++ b/rgthree-comfy/web/comfyui/display_any.js @@ -0,0 +1,35 @@ +import { app } from "../../scripts/app.js"; +import { ComfyWidgets } from "../../scripts/widgets.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { rgthree } from "./rgthree.js"; +let hasShownAlertForUpdatingInt = false; +app.registerExtension({ + name: "rgthree.DisplayAny", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name === "Display Any (rgthree)" || nodeData.name === "Display Int (rgthree)") { + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + onNodeCreated ? onNodeCreated.apply(this, []) : undefined; + this.showValueWidget = ComfyWidgets["STRING"](this, "output", ["STRING", { multiline: true }], app).widget; + this.showValueWidget.inputEl.readOnly = true; + this.showValueWidget.serializeValue = async (node, index) => { + const n = rgthree.getNodeFromInitialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff(node); + if (n) { + n.widgets_values[index] = ""; + } + else { + console.warn("No serialized node found in workflow. May be attributed to " + + "https://github.com/comfyanonymous/ComfyUI/issues/2193"); + } + return ""; + }; + }; + addConnectionLayoutSupport(nodeType, app, [["Left"], ["Right"]]); + const onExecuted = nodeType.prototype.onExecuted; + nodeType.prototype.onExecuted = function (message) { + onExecuted === null || onExecuted === void 0 ? 
void 0 : onExecuted.apply(this, [message]); + this.showValueWidget.value = message.text[0]; + }; + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/dynamic_context.js b/rgthree-comfy/web/comfyui/dynamic_context.js new file mode 100644 index 0000000000000000000000000000000000000000..5a2bc16a8916e0761e36690c4b7a342d797640e9 --- /dev/null +++ b/rgthree-comfy/web/comfyui/dynamic_context.js @@ -0,0 +1,253 @@ +import { app } from "../../scripts/app.js"; +import { IoDirection, followConnectionUntilType, getConnectedInputInfosAndFilterPassThroughs, } from "./utils.js"; +import { rgthree } from "./rgthree.js"; +import { SERVICE as CONTEXT_SERVICE, InputMutationOperation, } from "./services/context_service.js"; +import { NodeTypesString } from "./constants.js"; +import { removeUnusedInputsFromEnd } from "./utils_inputs_outputs.js"; +import { DynamicContextNodeBase } from "./dynamic_context_base.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +const OWNED_PREFIX = "+"; +const REGEX_OWNED_PREFIX = /^\+\s*/; +const REGEX_EMPTY_INPUT = /^\+\s*$/; +export class DynamicContextNode extends DynamicContextNodeBase { + constructor(title = DynamicContextNode.title) { + super(title); + } + onNodeCreated() { + this.addInput("base_ctx", "RGTHREE_DYNAMIC_CONTEXT"); + this.ensureOneRemainingNewInputSlot(); + super.onNodeCreated(); + } + onConnectionsChange(type, slotIndex, isConnected, link, ioSlot) { + var _a; + (_a = super.onConnectionsChange) === null || _a === void 0 ? void 0 : _a.call(this, type, slotIndex, isConnected, link, ioSlot); + if (this.configuring) { + return; + } + if (type === LiteGraph.INPUT) { + if (isConnected) { + this.handleInputConnected(slotIndex); + } + else { + this.handleInputDisconnected(slotIndex); + } + } + } + onConnectInput(inputIndex, outputType, outputSlot, outputNode, outputIndex) { + var _a; + let canConnect = true; + if (super.onConnectInput) { + canConnect = super.onConnectInput.apply(this, [...arguments]); + } + if (canConnect && + outputNode instanceof DynamicContextNode && + outputIndex === 0 && + inputIndex !== 0) { + const [n, v] = rgthree.logger.warnParts("Currently, you can only connect a context node in the first slot."); + (_a = console[n]) === null || _a === void 0 ? void 0 : _a.call(console, ...v); + canConnect = false; + } + return canConnect; + } + handleInputConnected(slotIndex) { + const ioSlot = this.inputs[slotIndex]; + const connectedIndexes = []; + if (slotIndex === 0) { + let baseNodeInfos = getConnectedInputInfosAndFilterPassThroughs(this, this, 0); + const baseNodes = baseNodeInfos.map((n) => n.node); + const baseNodesDynamicCtx = baseNodes[0]; + if (baseNodesDynamicCtx === null || baseNodesDynamicCtx === void 0 ? 
void 0 : baseNodesDynamicCtx.provideInputsData) { + const inputsData = CONTEXT_SERVICE.getDynamicContextInputsData(baseNodesDynamicCtx); + console.log("inputsData", inputsData); + for (const input of baseNodesDynamicCtx.provideInputsData()) { + if (input.name === "base_ctx" || input.name === "+") { + continue; + } + this.addContextInput(input.name, input.type, input.index); + this.stabilizeNames(); + } + } + } + else if (this.isInputSlotForNewInput(slotIndex)) { + this.handleNewInputConnected(slotIndex); + } + } + isInputSlotForNewInput(slotIndex) { + const ioSlot = this.inputs[slotIndex]; + return ioSlot && ioSlot.name === "+" && ioSlot.type === "*"; + } + handleNewInputConnected(slotIndex) { + if (!this.isInputSlotForNewInput(slotIndex)) { + throw new Error('Expected the incoming slot index to be the "new input" input.'); + } + const ioSlot = this.inputs[slotIndex]; + let cxn = null; + if (ioSlot.link != null) { + cxn = followConnectionUntilType(this, IoDirection.INPUT, slotIndex, true); + } + if ((cxn === null || cxn === void 0 ? void 0 : cxn.type) && (cxn === null || cxn === void 0 ? void 0 : cxn.name)) { + let name = this.addOwnedPrefix(this.getNextUniqueNameForThisNode(cxn.name)); + if (name.match(/^\+\s*[A-Z_]+(\.\d+)?$/)) { + name = name.toLowerCase(); + } + ioSlot.name = name; + ioSlot.type = cxn.type; + ioSlot.removable = true; + while (!this.outputs[slotIndex]) { + this.addOutput("*", "*"); + } + this.outputs[slotIndex].type = cxn.type; + this.outputs[slotIndex].name = this.stripOwnedPrefix(name).toLocaleUpperCase(); + if (cxn.type === "COMBO" || cxn.type.includes(",") || Array.isArray(cxn.type)) { + this.outputs[slotIndex].widget = true; + } + this.inputsMutated({ + operation: InputMutationOperation.ADDED, + node: this, + slotIndex, + slot: ioSlot, + }); + this.stabilizeNames(); + this.ensureOneRemainingNewInputSlot(); + } + } + handleInputDisconnected(slotIndex) { + var _a, _b; + const inputs = this.getContextInputsList(); + if (slotIndex === 0) { + for (let index = inputs.length - 1; index > 0; index--) { + if (index === 0 || index === inputs.length - 1) { + continue; + } + const input = inputs[index]; + if (!this.isOwnedInput(input.name)) { + if (input.link || ((_b = (_a = this.outputs[index]) === null || _a === void 0 ? void 0 : _a.links) === null || _b === void 0 ? 
void 0 : _b.length)) { + this.renameContextInput(index, input.name, true); + } + else { + this.removeContextInput(index); + } + } + } + this.setSize(this.computeSize()); + this.setDirtyCanvas(true, true); + } + } + ensureOneRemainingNewInputSlot() { + removeUnusedInputsFromEnd(this, 1, REGEX_EMPTY_INPUT); + this.addInput(OWNED_PREFIX, "*"); + } + getNextUniqueNameForThisNode(desiredName) { + const inputs = this.getContextInputsList(); + const allExistingKeys = inputs.map((i) => this.stripOwnedPrefix(i.name).toLocaleUpperCase()); + desiredName = this.stripOwnedPrefix(desiredName); + let newName = desiredName; + let n = 0; + while (allExistingKeys.includes(newName.toLocaleUpperCase())) { + newName = `${desiredName}.${++n}`; + } + return newName; + } + removeInput(slotIndex) { + const slot = this.inputs[slotIndex]; + super.removeInput(slotIndex); + if (this.outputs[slotIndex]) { + this.removeOutput(slotIndex); + } + this.inputsMutated({ operation: InputMutationOperation.REMOVED, node: this, slotIndex, slot }); + this.stabilizeNames(); + } + stabilizeNames() { + const inputs = this.getContextInputsList(); + const names = []; + for (const [index, input] of inputs.entries()) { + if (index === 0 || index === inputs.length - 1) { + continue; + } + input.label = undefined; + this.outputs[index].label = undefined; + let origName = this.stripOwnedPrefix(input.name).replace(/\.\d+$/, ""); + let name = input.name; + if (!this.isOwnedInput(name)) { + names.push(name.toLocaleUpperCase()); + } + else { + let n = 0; + name = this.addOwnedPrefix(origName); + while (names.includes(this.stripOwnedPrefix(name).toLocaleUpperCase())) { + name = `${this.addOwnedPrefix(origName)}.${++n}`; + } + names.push(this.stripOwnedPrefix(name).toLocaleUpperCase()); + if (input.name !== name) { + this.renameContextInput(index, name); + } + } + } + } + getSlotMenuOptions(slot) { + const editable = this.isOwnedInput(slot.input.name) && this.type !== "*"; + return [ + { + content: "✏️ Rename Input", + disabled: !editable, + callback: () => { + var dialog = app.canvas.createDialog("Name", {}); + var dialogInput = dialog.querySelector("input"); + if (dialogInput) { + dialogInput.value = this.stripOwnedPrefix(slot.input.name || ""); + } + var inner = () => { + this.handleContextMenuRenameInputDialog(slot.slot, dialogInput.value); + dialog.close(); + }; + dialog.querySelector("button").addEventListener("click", inner); + dialogInput.addEventListener("keydown", (e) => { + var _a; + dialog.is_modified = true; + if (e.keyCode == 27) { + dialog.close(); + } + else if (e.keyCode == 13) { + inner(); + } + else if (e.keyCode != 13 && ((_a = e.target) === null || _a === void 0 ? 
void 0 : _a.localName) != "textarea") { + return; + } + e.preventDefault(); + e.stopPropagation(); + }); + dialogInput.focus(); + }, + }, + { + content: "🗑️ Delete Input", + disabled: !editable, + callback: () => { + this.removeInput(slot.slot); + }, + }, + ]; + } + handleContextMenuRenameInputDialog(slotIndex, value) { + app.graph.beforeChange(); + this.renameContextInput(slotIndex, value); + this.stabilizeNames(); + this.setDirtyCanvas(true, true); + app.graph.afterChange(); + } +} +DynamicContextNode.title = NodeTypesString.DYNAMIC_CONTEXT; +DynamicContextNode.type = NodeTypesString.DYNAMIC_CONTEXT; +DynamicContextNode.comfyClass = NodeTypesString.DYNAMIC_CONTEXT; +const contextDynamicNodes = [DynamicContextNode]; +app.registerExtension({ + name: "rgthree.DynamicContext", + async beforeRegisterNodeDef(nodeType, nodeData) { + if (!CONFIG_SERVICE.getConfigValue("unreleased.dynamic_context.enabled")) { + return; + } + if (nodeData.name === DynamicContextNode.type) { + DynamicContextNode.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/dynamic_context_base.js b/rgthree-comfy/web/comfyui/dynamic_context_base.js new file mode 100644 index 0000000000000000000000000000000000000000..864aca997ac27c321db8e1a85935e02efb1e8c2d --- /dev/null +++ b/rgthree-comfy/web/comfyui/dynamic_context_base.js @@ -0,0 +1,189 @@ +import { BaseContextNode } from "./context.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { moveArrayItem, wait } from "../../rgthree/common/shared_utils.js"; +import { RgthreeInvisibleWidget } from "./utils_widgets.js"; +import { getContextOutputName, InputMutationOperation, } from "./services/context_service.js"; +import { app } from "../../scripts/app.js"; +import { SERVICE as CONTEXT_SERVICE } from "./services/context_service.js"; +const OWNED_PREFIX = "+"; +const REGEX_OWNED_PREFIX = /^\+\s*/; +const REGEX_EMPTY_INPUT = /^\+\s*$/; +export class DynamicContextNodeBase extends BaseContextNode { + constructor() { + super(...arguments); + this.hasShadowInputs = false; + } + getContextInputsList() { + return this.inputs; + } + provideInputsData() { + const inputs = this.getContextInputsList(); + return inputs + .map((input, index) => ({ + name: this.stripOwnedPrefix(input.name), + type: String(input.type), + index, + })) + .filter((i) => i.type !== "*"); + } + addOwnedPrefix(name) { + return `+ ${this.stripOwnedPrefix(name)}`; + } + isOwnedInput(inputOrName) { + const name = typeof inputOrName == "string" ? inputOrName : (inputOrName === null || inputOrName === void 0 ? 
void 0 : inputOrName.name) || ""; + return REGEX_OWNED_PREFIX.test(name); + } + stripOwnedPrefix(name) { + return name.replace(REGEX_OWNED_PREFIX, ""); + } + handleUpstreamMutation(mutation) { + console.log(`[node ${this.id}] handleUpstreamMutation`, mutation); + if (mutation.operation === InputMutationOperation.ADDED) { + const slot = mutation.slot; + if (!slot) { + throw new Error("Cannot have an ADDED mutation without a provided slot data."); + } + this.addContextInput(this.stripOwnedPrefix(slot.name), slot.type, mutation.slotIndex); + return; + } + if (mutation.operation === InputMutationOperation.REMOVED) { + const slot = mutation.slot; + if (!slot) { + throw new Error("Cannot have an REMOVED mutation without a provided slot data."); + } + this.removeContextInput(mutation.slotIndex); + return; + } + if (mutation.operation === InputMutationOperation.RENAMED) { + const slot = mutation.slot; + if (!slot) { + throw new Error("Cannot have an RENAMED mutation without a provided slot data."); + } + this.renameContextInput(mutation.slotIndex, slot.name); + return; + } + } + clone() { + const cloned = super.clone(); + while (cloned.inputs.length > 1) { + cloned.removeInput(cloned.inputs.length - 1); + } + while (cloned.widgets.length > 1) { + cloned.removeWidget(cloned.widgets.length - 1); + } + while (cloned.outputs.length > 1) { + cloned.removeOutput(cloned.outputs.length - 1); + } + return cloned; + } + onNodeCreated() { + const node = this; + this.addCustomWidget(new RgthreeInvisibleWidget("output_keys", "RGTHREE_DYNAMIC_CONTEXT_OUTPUTS", "", () => { + return (node.outputs || []) + .map((o, i) => i > 0 && o.name) + .filter((n) => n !== false) + .join(","); + })); + } + addContextInput(name, type, slot = -1) { + const inputs = this.getContextInputsList(); + if (this.hasShadowInputs) { + inputs.push({ name, type, link: null }); + } + else { + this.addInput(name, type); + } + if (slot > -1) { + moveArrayItem(inputs, inputs.length - 1, slot); + } + else { + slot = inputs.length - 1; + } + if (type !== "*") { + const output = this.addOutput(getContextOutputName(name), type); + if (type === "COMBO" || String(type).includes(",") || Array.isArray(type)) { + output.widget = true; + } + if (slot > -1) { + moveArrayItem(this.outputs, this.outputs.length - 1, slot); + } + } + this.fixInputsOutputsLinkSlots(); + this.inputsMutated({ + operation: InputMutationOperation.ADDED, + node: this, + slotIndex: slot, + slot: inputs[slot], + }); + } + removeContextInput(slotIndex) { + if (this.hasShadowInputs) { + const inputs = this.getContextInputsList(); + const input = inputs.splice(slotIndex, 1)[0]; + if (this.outputs[slotIndex]) { + this.removeOutput(slotIndex); + } + } + else { + this.removeInput(slotIndex); + } + } + renameContextInput(index, newName, forceOwnBool = null) { + const inputs = this.getContextInputsList(); + const input = inputs[index]; + const oldName = input.name; + newName = this.stripOwnedPrefix(newName.trim() || this.getSlotDefaultInputLabel(index)); + if (forceOwnBool === true || (this.isOwnedInput(oldName) && forceOwnBool !== false)) { + newName = this.addOwnedPrefix(newName); + } + if (oldName !== newName) { + input.name = newName; + input.removable = this.isOwnedInput(newName); + this.outputs[index].name = getContextOutputName(inputs[index].name); + this.inputsMutated({ + node: this, + operation: InputMutationOperation.RENAMED, + slotIndex: index, + slot: input, + }); + } + } + getSlotDefaultInputLabel(slotIndex) { + const inputs = this.getContextInputsList(); + const input = 
inputs[slotIndex]; + let defaultLabel = this.stripOwnedPrefix(input.name).toLowerCase(); + return defaultLabel.toLocaleLowerCase(); + } + inputsMutated(mutation) { + CONTEXT_SERVICE.onInputChanges(this, mutation); + } + fixInputsOutputsLinkSlots() { + if (!this.hasShadowInputs) { + const inputs = this.getContextInputsList(); + for (let index = inputs.length - 1; index > 0; index--) { + const input = inputs[index]; + if ((input === null || input === void 0 ? void 0 : input.link) != null) { + app.graph.links[input.link].target_slot = index; + } + } + } + const outputs = this.outputs; + for (let index = outputs.length - 1; index > 0; index--) { + const output = outputs[index]; + if (output) { + output.nameLocked = true; + for (const link of output.links || []) { + app.graph.links[link].origin_slot = index; + } + } + } + } + static setUp(comfyClass, nodeData) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, this); + wait(500).then(() => { + LiteGraph.slot_types_default_out["RGTHREE_DYNAMIC_CONTEXT"] = + LiteGraph.slot_types_default_out["RGTHREE_DYNAMIC_CONTEXT"] || []; + LiteGraph.slot_types_default_out["RGTHREE_DYNAMIC_CONTEXT"].push(comfyClass.comfyClass); + }); + } +} diff --git a/rgthree-comfy/web/comfyui/dynamic_context_switch.js b/rgthree-comfy/web/comfyui/dynamic_context_switch.js new file mode 100644 index 0000000000000000000000000000000000000000..9e5ebcd662cb7bc119284dbe24a717f394371b9c --- /dev/null +++ b/rgthree-comfy/web/comfyui/dynamic_context_switch.js @@ -0,0 +1,146 @@ +import { app } from "../../scripts/app.js"; +import { DynamicContextNodeBase } from "./dynamic_context_base.js"; +import { NodeTypesString } from "./constants.js"; +import { SERVICE as CONTEXT_SERVICE, getContextOutputName, } from "./services/context_service.js"; +import { getConnectedInputNodesAndFilterPassThroughs } from "./utils.js"; +import { debounce, moveArrayItem } from "../../rgthree/common/shared_utils.js"; +import { measureText } from "./utils_canvas.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +class DynamicContextSwitchNode extends DynamicContextNodeBase { + constructor(title = DynamicContextSwitchNode.title) { + super(title); + this.hasShadowInputs = true; + this.lastInputsList = []; + this.shadowInputs = [ + { name: "base_ctx", type: "RGTHREE_DYNAMIC_CONTEXT", link: null, count: 0 }, + ]; + } + getContextInputsList() { + return this.shadowInputs; + } + handleUpstreamMutation(mutation) { + this.scheduleHardRefresh(); + } + onConnectionsChange(type, slotIndex, isConnected, link, ioSlot) { + var _a; + (_a = super.onConnectionsChange) === null || _a === void 0 ? 
void 0 : _a.call(this, type, slotIndex, isConnected, link, ioSlot); + if (this.configuring) { + return; + } + if (type === LiteGraph.INPUT) { + this.scheduleHardRefresh(); + } + } + scheduleHardRefresh(ms = 64) { + return debounce(() => { + this.refreshInputsAndOutputs(); + }, ms); + } + onNodeCreated() { + this.addInput("ctx_1", "RGTHREE_DYNAMIC_CONTEXT"); + this.addInput("ctx_2", "RGTHREE_DYNAMIC_CONTEXT"); + this.addInput("ctx_3", "RGTHREE_DYNAMIC_CONTEXT"); + this.addInput("ctx_4", "RGTHREE_DYNAMIC_CONTEXT"); + this.addInput("ctx_5", "RGTHREE_DYNAMIC_CONTEXT"); + super.onNodeCreated(); + } + addContextInput(name, type, slot) { } + refreshInputsAndOutputs() { + var _a; + const inputs = [ + { name: "base_ctx", type: "RGTHREE_DYNAMIC_CONTEXT", link: null, count: 0 }, + ]; + let numConnected = 0; + for (let i = 0; i < this.inputs.length; i++) { + const childCtxs = getConnectedInputNodesAndFilterPassThroughs(this, this, i); + if (childCtxs.length > 1) { + throw new Error("How is there more than one input?"); + } + const ctx = childCtxs[0]; + if (!ctx) + continue; + numConnected++; + const slotsData = CONTEXT_SERVICE.getDynamicContextInputsData(ctx); + console.log(slotsData); + for (const slotData of slotsData) { + const found = inputs.find((n) => getContextOutputName(slotData.name) === getContextOutputName(n.name)); + if (found) { + found.count += 1; + continue; + } + inputs.push({ + name: slotData.name, + type: slotData.type, + link: null, + count: 1, + }); + } + } + this.shadowInputs = inputs; + let i = 0; + for (i; i < this.shadowInputs.length; i++) { + const data = this.shadowInputs[i]; + let existing = this.outputs.find((o) => getContextOutputName(o.name) === getContextOutputName(data.name)); + if (!existing) { + existing = this.addOutput(getContextOutputName(data.name), data.type); + } + moveArrayItem(this.outputs, existing, i); + delete existing.rgthree_status; + if (data.count !== numConnected) { + existing.rgthree_status = "WARN"; + } + } + while (this.outputs[i]) { + const output = this.outputs[i]; + if ((_a = output === null || output === void 0 ? void 0 : output.links) === null || _a === void 0 ? void 0 : _a.length) { + output.rgthree_status = "ERROR"; + i++; + } + else { + this.removeOutput(i); + } + } + this.fixInputsOutputsLinkSlots(); + } + onDrawForeground(ctx, canvas) { + var _a, _b; + const low_quality = ((_b = (_a = canvas === null || canvas === void 0 ? void 0 : canvas.ds) === null || _a === void 0 ? void 0 : _a.scale) !== null && _b !== void 0 ? 
_b : 1) < 0.6; + if (low_quality || this.size[0] <= 10) { + return; + } + let y = LiteGraph.NODE_SLOT_HEIGHT - 1; + const w = this.size[0]; + ctx.save(); + ctx.font = "normal " + LiteGraph.NODE_SUBTEXT_SIZE + "px Arial"; + ctx.textAlign = "right"; + for (const output of this.outputs) { + if (!output.rgthree_status) { + y += LiteGraph.NODE_SLOT_HEIGHT; + continue; + } + const x = w - 20 - measureText(ctx, output.name); + if (output.rgthree_status === "ERROR") { + ctx.fillText("🛑", x, y); + } + else if (output.rgthree_status === "WARN") { + ctx.fillText("⚠️", x, y); + } + y += LiteGraph.NODE_SLOT_HEIGHT; + } + ctx.restore(); + } +} +DynamicContextSwitchNode.title = NodeTypesString.DYNAMIC_CONTEXT_SWITCH; +DynamicContextSwitchNode.type = NodeTypesString.DYNAMIC_CONTEXT_SWITCH; +DynamicContextSwitchNode.comfyClass = NodeTypesString.DYNAMIC_CONTEXT_SWITCH; +app.registerExtension({ + name: "rgthree.DynamicContextSwitch", + async beforeRegisterNodeDef(nodeType, nodeData) { + if (!CONFIG_SERVICE.getConfigValue("unreleased.dynamic_context.enabled")) { + return; + } + if (nodeData.name === DynamicContextSwitchNode.type) { + DynamicContextSwitchNode.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/fast_actions_button.js b/rgthree-comfy/web/comfyui/fast_actions_button.js new file mode 100644 index 0000000000000000000000000000000000000000..fa8344d9ef3a76018ddce4c513dafea1cc20d110 --- /dev/null +++ b/rgthree-comfy/web/comfyui/fast_actions_button.js @@ -0,0 +1,266 @@ +import { app } from "../../scripts/app.js"; +import { BaseAnyInputConnectedNode } from "./base_any_input_connected_node.js"; +import { NodeTypesString } from "./constants.js"; +import { addMenuItem } from "./utils.js"; +import { rgthree } from "./rgthree.js"; +const MODE_ALWAYS = 0; +const MODE_MUTE = 2; +const MODE_BYPASS = 4; +class FastActionsButton extends BaseAnyInputConnectedNode { + constructor(title) { + super(title); + this.comfyClass = NodeTypesString.FAST_ACTIONS_BUTTON; + this.logger = rgthree.newLogSession("[FastActionsButton]"); + this.isVirtualNode = true; + this.serialize_widgets = true; + this.widgetToData = new Map(); + this.nodeIdtoFunctionCache = new Map(); + this.executingFromShortcut = false; + this.properties["buttonText"] = "🎬 Action!"; + this.properties["shortcutModifier"] = "alt"; + this.properties["shortcutKey"] = ""; + this.buttonWidget = this.addWidget("button", this.properties["buttonText"], null, () => { + this.executeConnectedNodes(); + }, { serialize: false }); + this.keypressBound = this.onKeypress.bind(this); + this.keyupBound = this.onKeyup.bind(this); + this.onConstructed(); + } + configure(info) { + super.configure(info); + setTimeout(() => { + if (info.widgets_values) { + for (let [index, value] of info.widgets_values.entries()) { + if (index > 0) { + if (value.startsWith("comfy_action:")) { + value = value.replace("comfy_action:", ""); + this.addComfyActionWidget(index, value); + } + if (this.widgets[index]) { + this.widgets[index].value = value; + } + } + } + } + }, 100); + } + clone() { + const cloned = super.clone(); + cloned.properties["buttonText"] = "🎬 Action!"; + cloned.properties["shortcutKey"] = ""; + return cloned; + } + onAdded(graph) { + window.addEventListener("keydown", this.keypressBound); + window.addEventListener("keyup", this.keyupBound); + } + onRemoved() { + window.removeEventListener("keydown", this.keypressBound); + window.removeEventListener("keyup", this.keyupBound); + } + async onKeypress(event) { + const target = event.target; + if 
(this.executingFromShortcut || + target.localName == "input" || + target.localName == "textarea") { + return; + } + if (this.properties["shortcutKey"].trim() && + this.properties["shortcutKey"].toLowerCase() === event.key.toLowerCase()) { + const shortcutModifier = this.properties["shortcutModifier"]; + let good = shortcutModifier === "ctrl" && event.ctrlKey; + good = good || (shortcutModifier === "alt" && event.altKey); + good = good || (shortcutModifier === "shift" && event.shiftKey); + good = good || (shortcutModifier === "meta" && event.metaKey); + if (good) { + setTimeout(() => { + this.executeConnectedNodes(); + }, 20); + this.executingFromShortcut = true; + event.preventDefault(); + event.stopImmediatePropagation(); + app.canvas.dirty_canvas = true; + return false; + } + } + return; + } + onKeyup(event) { + const target = event.target; + if (target.localName == "input" || target.localName == "textarea") { + return; + } + this.executingFromShortcut = false; + } + onPropertyChanged(property, value, _prevValue) { + if (property == "buttonText") { + this.buttonWidget.name = value; + } + if (property == "shortcutKey") { + value = value.trim(); + this.properties["shortcutKey"] = (value && value[0].toLowerCase()) || ""; + } + } + handleLinkedNodesStabilization(linkedNodes) { + var _a, _b, _c, _d, _e, _f, _g, _h; + for (const [widget, data] of this.widgetToData.entries()) { + if (!data.node) { + continue; + } + if (!linkedNodes.includes(data.node)) { + const index = this.widgets.indexOf(widget); + if (index > -1) { + this.widgetToData.delete(widget); + this.removeWidget(widget); + } + else { + const [m, a] = this.logger.debugParts("Connected widget is not in widgets... weird."); + (_a = console[m]) === null || _a === void 0 ? void 0 : _a.call(console, ...a); + } + } + } + const badNodes = []; + let indexOffset = 1; + for (const [index, node] of linkedNodes.entries()) { + if (!node) { + const [m, a] = this.logger.debugParts("linkedNode provided that does not exist. "); + (_b = console[m]) === null || _b === void 0 ? void 0 : _b.call(console, ...a); + badNodes.push(node); + continue; + } + let widgetAtSlot = this.widgets[index + indexOffset]; + if (widgetAtSlot && ((_c = this.widgetToData.get(widgetAtSlot)) === null || _c === void 0 ? void 0 : _c.comfy)) { + indexOffset++; + widgetAtSlot = this.widgets[index + indexOffset]; + } + if (!widgetAtSlot || ((_e = (_d = this.widgetToData.get(widgetAtSlot)) === null || _d === void 0 ? void 0 : _d.node) === null || _e === void 0 ? void 0 : _e.id) !== node.id) { + let widget = null; + for (let i = index + indexOffset; i < this.widgets.length; i++) { + if (((_g = (_f = this.widgetToData.get(this.widgets[i])) === null || _f === void 0 ? void 0 : _f.node) === null || _g === void 0 ? void 0 : _g.id) === node.id) { + widget = this.widgets.splice(i, 1)[0]; + this.widgets.splice(index + indexOffset, 0, widget); + break; + } + } + if (!widget) { + const exposedActions = node.constructor.exposedActions || []; + widget = this.addWidget("combo", node.title, "None", "", { + values: ["None", "Mute", "Bypass", "Enable", ...exposedActions], + }); + widget.serializeValue = async (_node, _index) => { + return widget === null || widget === void 0 ? void 0 : widget.value; + }; + this.widgetToData.set(widget, { node }); + } + } + } + for (let i = this.widgets.length - 1; i > linkedNodes.length + indexOffset - 1; i--) { + const widgetAtSlot = this.widgets[i]; + if (widgetAtSlot && ((_h = this.widgetToData.get(widgetAtSlot)) === null || _h === void 0 ? 
void 0 : _h.comfy)) { + continue; + } + this.removeWidget(widgetAtSlot); + } + } + removeWidget(widgetOrSlot) { + const widget = typeof widgetOrSlot === "number" ? this.widgets[widgetOrSlot] : widgetOrSlot; + if (widget && this.widgetToData.has(widget)) { + this.widgetToData.delete(widget); + } + super.removeWidget(widgetOrSlot); + } + async executeConnectedNodes() { + var _a; + for (const widget of this.widgets) { + if (widget == this.buttonWidget) { + continue; + } + const action = widget.value; + const { comfy, node } = (_a = this.widgetToData.get(widget)) !== null && _a !== void 0 ? _a : {}; + if (comfy) { + if (action === "Queue Prompt") { + await comfy.queuePrompt(0); + } + continue; + } + if (node) { + if (action === "Mute") { + node.mode = MODE_MUTE; + } + else if (action === "Bypass") { + node.mode = MODE_BYPASS; + } + else if (action === "Enable") { + node.mode = MODE_ALWAYS; + } + if (node.handleAction) { + await node.handleAction(action); + } + app.graph.change(); + continue; + } + console.warn("Fast Actions Button has a widget without correct data."); + } + } + addComfyActionWidget(slot, value) { + let widget = this.addWidget("combo", "Comfy Action", "None", () => { + if (widget.value.startsWith("MOVE ")) { + this.widgets.push(this.widgets.splice(this.widgets.indexOf(widget), 1)[0]); + widget.value = widget["lastValue_"]; + } + else if (widget.value.startsWith("REMOVE ")) { + this.removeWidget(widget); + } + widget["lastValue_"] = widget.value; + }, { + values: ["None", "Queue Prompt", "REMOVE Comfy Action", "MOVE to end"], + }); + widget["lastValue_"] = value; + widget.serializeValue = async (_node, _index) => { + return `comfy_app:${widget === null || widget === void 0 ? void 0 : widget.value}`; + }; + this.widgetToData.set(widget, { comfy: app }); + if (slot != null) { + this.widgets.splice(slot, 0, this.widgets.splice(this.widgets.indexOf(widget), 1)[0]); + } + return widget; + } + onSerialize(o) { + var _a; + super.onSerialize && super.onSerialize(o); + for (let [index, value] of (o.widgets_values || []).entries()) { + if (((_a = this.widgets[index]) === null || _a === void 0 ? 
void 0 : _a.name) === "Comfy Action") { + o.widgets_values[index] = `comfy_action:${value}`; + } + } + } + static setUp() { + super.setUp(); + addMenuItem(this, app, { + name: "➕ Append a Comfy Action", + callback: (nodeArg) => { + nodeArg.addComfyActionWidget(); + }, + }); + } +} +FastActionsButton.type = NodeTypesString.FAST_ACTIONS_BUTTON; +FastActionsButton.title = NodeTypesString.FAST_ACTIONS_BUTTON; +FastActionsButton["@buttonText"] = { type: "string" }; +FastActionsButton["@shortcutModifier"] = { + type: "combo", + values: ["ctrl", "alt", "shift"], +}; +FastActionsButton["@shortcutKey"] = { type: "string" }; +FastActionsButton.collapsible = false; +app.registerExtension({ + name: "rgthree.FastActionsButton", + registerCustomNodes() { + FastActionsButton.setUp(); + }, + loadedGraphNode(node) { + if (node.type == FastActionsButton.title) { + node._tempWidth = node.size[0]; + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/fast_groups_bypasser.js b/rgthree-comfy/web/comfyui/fast_groups_bypasser.js new file mode 100644 index 0000000000000000000000000000000000000000..80ee99108156aac3d96e66cbe91db224b5877c99 --- /dev/null +++ b/rgthree-comfy/web/comfyui/fast_groups_bypasser.js @@ -0,0 +1,27 @@ +import { app } from "../../scripts/app.js"; +import { NodeTypesString } from "./constants.js"; +import { BaseFastGroupsModeChanger } from "./fast_groups_muter.js"; +export class FastGroupsBypasser extends BaseFastGroupsModeChanger { + constructor(title = FastGroupsBypasser.title) { + super(title); + this.comfyClass = NodeTypesString.FAST_GROUPS_BYPASSER; + this.helpActions = "bypass and enable"; + this.modeOn = LiteGraph.ALWAYS; + this.modeOff = 4; + this.onConstructed(); + } +} +FastGroupsBypasser.type = NodeTypesString.FAST_GROUPS_BYPASSER; +FastGroupsBypasser.title = NodeTypesString.FAST_GROUPS_BYPASSER; +FastGroupsBypasser.exposedActions = ["Bypass all", "Enable all", "Toggle all"]; +app.registerExtension({ + name: "rgthree.FastGroupsBypasser", + registerCustomNodes() { + FastGroupsBypasser.setUp(); + }, + loadedGraphNode(node) { + if (node.type == FastGroupsBypasser.title) { + node.tempSize = [...node.size]; + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/fast_groups_muter.js b/rgthree-comfy/web/comfyui/fast_groups_muter.js new file mode 100644 index 0000000000000000000000000000000000000000..a0000017f301340274da9f67a3caed3f1f287735 --- /dev/null +++ b/rgthree-comfy/web/comfyui/fast_groups_muter.js @@ -0,0 +1,418 @@ +import { app } from "../../scripts/app.js"; +import { RgthreeBaseVirtualNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; +import { SERVICE as FAST_GROUPS_SERVICE } from "./services/fast_groups_service.js"; +import { drawNodeWidget, fitString } from "./utils_canvas.js"; +const PROPERTY_SORT = "sort"; +const PROPERTY_SORT_CUSTOM_ALPHA = "customSortAlphabet"; +const PROPERTY_MATCH_COLORS = "matchColors"; +const PROPERTY_MATCH_TITLE = "matchTitle"; +const PROPERTY_SHOW_NAV = "showNav"; +const PROPERTY_RESTRICTION = "toggleRestriction"; +export class BaseFastGroupsModeChanger extends RgthreeBaseVirtualNode { + constructor(title = FastGroupsMuter.title) { + super(title); + this.modeOn = LiteGraph.ALWAYS; + this.modeOff = LiteGraph.NEVER; + this.debouncerTempWidth = 0; + this.tempSize = null; + this.serialize_widgets = false; + this.helpActions = "mute and unmute"; + this.properties[PROPERTY_MATCH_COLORS] = ""; + this.properties[PROPERTY_MATCH_TITLE] = ""; + this.properties[PROPERTY_SHOW_NAV] = true; + this.properties[PROPERTY_SORT] = "position"; + 
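+        // When the "sort" property is "custom alphabet", the string set below (via the node's
+        // properties panel) defines the ordering used by refreshWidgets(): either comma-delimited
+        // prefixes (e.g. "sdxl,pro,sd") or a run of single characters (e.g. "zyxw").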
this.properties[PROPERTY_SORT_CUSTOM_ALPHA] = ""; + this.properties[PROPERTY_RESTRICTION] = "default"; + } + onConstructed() { + this.addOutput("OPT_CONNECTION", "*"); + return super.onConstructed(); + } + configure(info) { + var _a; + if ((_a = info.outputs) === null || _a === void 0 ? void 0 : _a.length) { + info.outputs.length = 1; + } + super.configure(info); + } + onAdded(graph) { + FAST_GROUPS_SERVICE.addFastGroupNode(this); + } + onRemoved() { + FAST_GROUPS_SERVICE.removeFastGroupNode(this); + } + refreshWidgets() { + var _a, _b, _c, _d, _e, _f, _g, _h; + const canvas = app.canvas; + let sort = ((_a = this.properties) === null || _a === void 0 ? void 0 : _a[PROPERTY_SORT]) || "position"; + let customAlphabet = null; + if (sort === "custom alphabet") { + const customAlphaStr = (_c = (_b = this.properties) === null || _b === void 0 ? void 0 : _b[PROPERTY_SORT_CUSTOM_ALPHA]) === null || _c === void 0 ? void 0 : _c.replace(/\n/g, ""); + if (customAlphaStr && customAlphaStr.trim()) { + customAlphabet = customAlphaStr.includes(",") + ? customAlphaStr.toLocaleLowerCase().split(",") + : customAlphaStr.toLocaleLowerCase().trim().split(""); + } + if (!(customAlphabet === null || customAlphabet === void 0 ? void 0 : customAlphabet.length)) { + sort = "alphanumeric"; + customAlphabet = null; + } + } + const groups = [...FAST_GROUPS_SERVICE.getGroups(sort)]; + if (customAlphabet === null || customAlphabet === void 0 ? void 0 : customAlphabet.length) { + groups.sort((a, b) => { + let aIndex = -1; + let bIndex = -1; + for (const [index, alpha] of customAlphabet.entries()) { + aIndex = + aIndex < 0 ? (a.title.toLocaleLowerCase().startsWith(alpha) ? index : -1) : aIndex; + bIndex = + bIndex < 0 ? (b.title.toLocaleLowerCase().startsWith(alpha) ? index : -1) : bIndex; + if (aIndex > -1 && bIndex > -1) { + break; + } + } + if (aIndex > -1 && bIndex > -1) { + const ret = aIndex - bIndex; + if (ret === 0) { + return a.title.localeCompare(b.title); + } + return ret; + } + else if (aIndex > -1) { + return -1; + } + else if (bIndex > -1) { + return 1; + } + return a.title.localeCompare(b.title); + }); + } + let filterColors = (((_e = (_d = this.properties) === null || _d === void 0 ? void 0 : _d[PROPERTY_MATCH_COLORS]) === null || _e === void 0 ? void 0 : _e.split(",")) || []).filter((c) => c.trim()); + if (filterColors.length) { + filterColors = filterColors.map((color) => { + color = color.trim().toLocaleLowerCase(); + if (LGraphCanvas.node_colors[color]) { + color = LGraphCanvas.node_colors[color].groupcolor; + } + color = color.replace("#", "").toLocaleLowerCase(); + if (color.length === 3) { + color = color.replace(/(.)(.)(.)/, "$1$1$2$2$3$3"); + } + return `#${color}`; + }); + } + let index = 0; + for (const group of groups) { + if (filterColors.length) { + let groupColor = (_f = group.color) === null || _f === void 0 ? void 0 : _f.replace("#", "").trim().toLocaleLowerCase(); + if (!groupColor) { + continue; + } + if (groupColor.length === 3) { + groupColor = groupColor.replace(/(.)(.)(.)/, "$1$1$2$2$3$3"); + } + groupColor = `#${groupColor}`; + if (!filterColors.includes(groupColor)) { + continue; + } + } + if ((_h = (_g = this.properties) === null || _g === void 0 ? void 0 : _g[PROPERTY_MATCH_TITLE]) === null || _h === void 0 ? 
void 0 : _h.trim()) { + try { + if (!new RegExp(this.properties[PROPERTY_MATCH_TITLE], "i").exec(group.title)) { + continue; + } + } + catch (e) { + console.error(e); + continue; + } + } + const widgetName = `Enable ${group.title}`; + let widget = this.widgets.find((w) => w.name === widgetName); + if (!widget) { + this.tempSize = [...this.size]; + widget = this.addCustomWidget({ + name: "RGTHREE_TOGGLE_AND_NAV", + label: "", + value: false, + disabled: false, + options: { on: "yes", off: "no" }, + draw: function (ctx, node, width, posY, height) { + var _a; + const widgetData = drawNodeWidget(ctx, { + width, + height, + posY, + }); + const showNav = ((_a = node.properties) === null || _a === void 0 ? void 0 : _a[PROPERTY_SHOW_NAV]) !== false; + let currentX = widgetData.width - widgetData.margin; + if (!widgetData.lowQuality && showNav) { + currentX -= 7; + const midY = widgetData.posY + widgetData.height * 0.5; + ctx.fillStyle = ctx.strokeStyle = "#89A"; + ctx.lineJoin = "round"; + ctx.lineCap = "round"; + const arrow = new Path2D(`M${currentX} ${midY} l -7 6 v -3 h -7 v -6 h 7 v -3 z`); + ctx.fill(arrow); + ctx.stroke(arrow); + currentX -= 14; + currentX -= 7; + ctx.strokeStyle = widgetData.colorOutline; + ctx.stroke(new Path2D(`M ${currentX} ${widgetData.posY} v ${widgetData.height}`)); + } + else if (widgetData.lowQuality && showNav) { + currentX -= 28; + } + currentX -= 7; + ctx.fillStyle = this.value ? "#89A" : "#333"; + ctx.beginPath(); + const toggleRadius = height * 0.36; + ctx.arc(currentX - toggleRadius, posY + height * 0.5, toggleRadius, 0, Math.PI * 2); + ctx.fill(); + currentX -= toggleRadius * 2; + if (!widgetData.lowQuality) { + currentX -= 4; + ctx.textAlign = "right"; + ctx.fillStyle = this.value ? widgetData.colorText : widgetData.colorTextSecondary; + const label = this.label || this.name; + const toggleLabelOn = this.options.on || "true"; + const toggleLabelOff = this.options.off || "false"; + ctx.fillText(this.value ? toggleLabelOn : toggleLabelOff, currentX, posY + height * 0.7); + currentX -= Math.max(ctx.measureText(toggleLabelOn).width, ctx.measureText(toggleLabelOff).width); + currentX -= 7; + ctx.textAlign = "left"; + let maxLabelWidth = widgetData.width - widgetData.margin - 10 - (widgetData.width - currentX); + if (label != null) { + ctx.fillText(fitString(ctx, label, maxLabelWidth), widgetData.margin + 10, posY + height * 0.7); + } + } + }, + serializeValue(serializedNode, widgetIndex) { + return this.value; + }, + mouse(event, pos, node) { + var _a, _b, _c; + if (event.type == "pointerdown") { + if (((_a = node.properties) === null || _a === void 0 ? void 0 : _a[PROPERTY_SHOW_NAV]) !== false && + pos[0] >= node.size[0] - 15 - 28 - 1) { + const canvas = app.canvas; + const lowQuality = (((_b = canvas.ds) === null || _b === void 0 ? void 0 : _b.scale) || 1) <= 0.5; + if (!lowQuality) { + canvas.centerOnNode(group); + const zoomCurrent = ((_c = canvas.ds) === null || _c === void 0 ? void 0 : _c.scale) || 1; + const zoomX = canvas.canvas.width / group._size[0] - 0.02; + const zoomY = canvas.canvas.height / group._size[1] - 0.02; + canvas.setZoom(Math.min(zoomCurrent, zoomX, zoomY), [ + canvas.canvas.width / 2, + canvas.canvas.height / 2, + ]); + canvas.setDirty(true, true); + } + } + else { + this.value = !this.value; + setTimeout(() => { + var _a; + (_a = this.callback) === null || _a === void 0 ? 
void 0 : _a.call(this, this.value, app.canvas, node, pos, event); + }, 20); + } + } + return true; + }, + }); + widget.doModeChange = (force, skipOtherNodeCheck) => { + var _a, _b, _c; + group.recomputeInsideNodes(); + const hasAnyActiveNodes = group._nodes.some((n) => n.mode === LiteGraph.ALWAYS); + let newValue = force != null ? force : !hasAnyActiveNodes; + if (skipOtherNodeCheck !== true) { + if (newValue && ((_b = (_a = this.properties) === null || _a === void 0 ? void 0 : _a[PROPERTY_RESTRICTION]) === null || _b === void 0 ? void 0 : _b.includes(" one"))) { + for (const widget of this.widgets) { + widget.doModeChange(false, true); + } + } + else if (!newValue && ((_c = this.properties) === null || _c === void 0 ? void 0 : _c[PROPERTY_RESTRICTION]) === "always one") { + newValue = this.widgets.every((w) => !w.value || w === widget); + } + } + for (const node of group._nodes) { + node.mode = (newValue ? this.modeOn : this.modeOff); + } + group._rgthreeHasAnyActiveNode = newValue; + widget.value = newValue; + app.graph.setDirtyCanvas(true, false); + }; + widget.callback = () => { + widget.doModeChange(); + }; + this.setSize(this.computeSize()); + } + if (widget.name != widgetName) { + widget.name = widgetName; + this.setDirtyCanvas(true, false); + } + if (widget.value != group._rgthreeHasAnyActiveNode) { + widget.value = group._rgthreeHasAnyActiveNode; + this.setDirtyCanvas(true, false); + } + if (this.widgets[index] !== widget) { + const oldIndex = this.widgets.findIndex((w) => w === widget); + this.widgets.splice(index, 0, this.widgets.splice(oldIndex, 1)[0]); + this.setDirtyCanvas(true, false); + } + index++; + } + while ((this.widgets || [])[index]) { + this.removeWidget(index++); + } + } + computeSize(out) { + let size = super.computeSize(out); + if (this.tempSize) { + size[0] = Math.max(this.tempSize[0], size[0]); + size[1] = Math.max(this.tempSize[1], size[1]); + this.debouncerTempWidth && clearTimeout(this.debouncerTempWidth); + this.debouncerTempWidth = setTimeout(() => { + this.tempSize = null; + }, 32); + } + setTimeout(() => { + app.graph.setDirtyCanvas(true, true); + }, 16); + return size; + } + async handleAction(action) { + var _a, _b, _c, _d, _e; + if (action === "Mute all" || action === "Bypass all") { + const alwaysOne = ((_a = this.properties) === null || _a === void 0 ? void 0 : _a[PROPERTY_RESTRICTION]) === "always one"; + for (const [index, widget] of this.widgets.entries()) { + widget === null || widget === void 0 ? void 0 : widget.doModeChange(alwaysOne && !index ? true : false, true); + } + } + else if (action === "Enable all") { + const onlyOne = (_b = this.properties) === null || _b === void 0 ? void 0 : _b[PROPERTY_RESTRICTION].includes(" one"); + for (const [index, widget] of this.widgets.entries()) { + widget === null || widget === void 0 ? void 0 : widget.doModeChange(onlyOne && index > 0 ? false : true, true); + } + } + else if (action === "Toggle all") { + const onlyOne = (_c = this.properties) === null || _c === void 0 ? void 0 : _c[PROPERTY_RESTRICTION].includes(" one"); + let foundOne = false; + for (const [index, widget] of this.widgets.entries()) { + let newValue = onlyOne && foundOne ? false : !widget.value; + foundOne = foundOne || newValue; + widget === null || widget === void 0 ? void 0 : widget.doModeChange(newValue, true); + } + if (!foundOne && ((_d = this.properties) === null || _d === void 0 ? void 0 : _d[PROPERTY_RESTRICTION]) === "always one") { + (_e = this.widgets[this.widgets.length - 1]) === null || _e === void 0 ? 
void 0 : _e.doModeChange(true, true); + } + } + } + getHelp() { + return ` +

+            <p>
+                The ${this.type.replace("(rgthree)", "")} is an input-less node that automatically
+                collects all groups in your current workflow and allows you to quickly
+                ${this.helpActions} all nodes within the group.
+            </p>
+            <ul>
+                <li>
+                    <p>
+                        Properties. You can change the following properties (by right-clicking on
+                        the node and selecting "Properties" or "Properties Panel" from the menu):
+                    </p>
+                    <ul>
+                        <li>
+                            <code>${PROPERTY_MATCH_COLORS}</code> - Only add groups that match the
+                            provided colors. Can be ComfyUI colors (red, pale_blue) or hex codes
+                            (#a4d399). Multiple can be added, comma delimited.
+                        </li>
+                        <li>
+                            <code>${PROPERTY_MATCH_TITLE}</code> - Filter the list of toggles by
+                            title match (string match, or regular expression).
+                        </li>
+                        <li>
+                            <code>${PROPERTY_SHOW_NAV}</code> - Add / remove a quick navigation
+                            arrow to take you to the group. (default: true)
+                        </li>
+                        <li>
+                            <code>${PROPERTY_SORT}</code> - Sort the toggles' order by
+                            "alphanumeric", graph "position", or "custom alphabet".
+                            (default: "position")
+                        </li>
+                        <li>
+                            <p>
+                                <code>${PROPERTY_SORT_CUSTOM_ALPHA}</code> - When the
+                                <code>${PROPERTY_SORT}</code> property is "custom alphabet" you can
+                                define the alphabet to use here, which will match the beginning of
+                                each group name and sort against it. If group titles do not match
+                                any custom alphabet entry, then they will be put after groups that
+                                do, ordered alphanumerically.
+                            </p>
+                            <p>
+                                This can be a list of single characters, like "zyxw..." or comma
+                                delimited strings for more control, like "sdxl,pro,sd,n,p".
+                            </p>
+                            <p>
+                                Note: when two group titles match the same custom alphabet entry,
+                                the normal alphanumeric alphabet breaks the tie. For instance, a
+                                custom alphabet of "e,s,d" will order group names like "SDXL, SEGS,
+                                Detailer" even though the custom alphabet has an "e" before "d"
+                                (where one may expect "SE" to be before "SD").
+                            </p>
+                            <p>
+                                To have "SEGS" appear before "SDXL" you can use longer strings. For
+                                instance, the custom alphabet value of "se,s,f" would work here.
+                            </p>
+                        </li>
+                        <li>
+                            <p>
+                                <code>${PROPERTY_RESTRICTION}</code> - Optionally, attempt to
+                                restrict the number of widgets that can be enabled to a maximum of
+                                one, or always one.
+                            </p>
+                            <p>
+                                Note: If using "max one" or "always one" then this is only enforced
+                                when clicking a toggle on this node; if nodes within groups are
+                                changed outside of the initial toggle click, then these restrictions
+                                will not be enforced, and could result in a state where more than
+                                one toggle is enabled. This could also happen if nodes are
+                                overlapped with multiple groups.
+                            </p>
+                        </li>
+                    </ul>
+                </li>
+            </ul>
`; + } +} +BaseFastGroupsModeChanger.type = NodeTypesString.FAST_GROUPS_MUTER; +BaseFastGroupsModeChanger.title = NodeTypesString.FAST_GROUPS_MUTER; +BaseFastGroupsModeChanger.exposedActions = ["Mute all", "Enable all", "Toggle all"]; +BaseFastGroupsModeChanger["@matchColors"] = { type: "string" }; +BaseFastGroupsModeChanger["@matchTitle"] = { type: "string" }; +BaseFastGroupsModeChanger["@showNav"] = { type: "boolean" }; +BaseFastGroupsModeChanger["@sort"] = { + type: "combo", + values: ["position", "alphanumeric", "custom alphabet"], +}; +BaseFastGroupsModeChanger["@customSortAlphabet"] = { type: "string" }; +BaseFastGroupsModeChanger["@toggleRestriction"] = { + type: "combo", + values: ["default", "max one", "always one"], +}; +export class FastGroupsMuter extends BaseFastGroupsModeChanger { + constructor(title = FastGroupsMuter.title) { + super(title); + this.comfyClass = NodeTypesString.FAST_GROUPS_MUTER; + this.helpActions = "mute and unmute"; + this.modeOn = LiteGraph.ALWAYS; + this.modeOff = LiteGraph.NEVER; + this.onConstructed(); + } +} +FastGroupsMuter.type = NodeTypesString.FAST_GROUPS_MUTER; +FastGroupsMuter.title = NodeTypesString.FAST_GROUPS_MUTER; +FastGroupsMuter.exposedActions = ["Bypass all", "Enable all", "Toggle all"]; +app.registerExtension({ + name: "rgthree.FastGroupsMuter", + registerCustomNodes() { + FastGroupsMuter.setUp(); + }, + loadedGraphNode(node) { + if (node.type == FastGroupsMuter.title) { + node.tempSize = [...node.size]; + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/feature_group_fast_toggle.js b/rgthree-comfy/web/comfyui/feature_group_fast_toggle.js new file mode 100644 index 0000000000000000000000000000000000000000..aee1fdbb2ae52f79cd15c90fae02573c5ea8ed02 --- /dev/null +++ b/rgthree-comfy/web/comfyui/feature_group_fast_toggle.js @@ -0,0 +1,169 @@ +import { app } from "../../scripts/app.js"; +import { rgthree } from "./rgthree.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +const BTN_SIZE = 20; +const BTN_MARGIN = [6, 4]; +const BTN_SPACING = 8; +const BTN_GRID = BTN_SIZE / 8; +const TOGGLE_TO_MODE = new Map([ + ["MUTE", LiteGraph.NEVER], + ["BYPASS", 4], +]); +function clickedOnToggleButton(e, group) { + const toggles = CONFIG_SERVICE.getFeatureValue("group_header_fast_toggle.toggles"); + const pos = group.pos; + const size = group.size; + for (let i = 0; i < toggles.length; i++) { + const toggle = toggles[i]; + if (LiteGraph.isInsideRectangle(e.canvasX, e.canvasY, pos[0] + size[0] - (BTN_SIZE + BTN_MARGIN[0]) * (i + 1), pos[1] + BTN_MARGIN[1], BTN_SIZE, BTN_SIZE)) { + return toggle; + } + } + return null; +} +app.registerExtension({ + name: "rgthree.GroupHeaderToggles", + async setup() { + rgthree.addEventListener("on-process-mouse-down", ((e) => { + if (!CONFIG_SERVICE.getFeatureValue("group_header_fast_toggle.enabled")) + return; + const canvas = app.canvas; + if (canvas.selected_group) { + const originalEvent = e.detail.originalEvent; + const group = canvas.selected_group; + const clickedOnToggle = clickedOnToggleButton(originalEvent, group) || ""; + const toggleMode = TOGGLE_TO_MODE.get(clickedOnToggle === null || clickedOnToggle === void 0 ? 
void 0 : clickedOnToggle.toLocaleUpperCase()); + if (toggleMode) { + group.recomputeInsideNodes(); + const hasAnyActiveNodes = group._nodes.some((n) => n.mode === LiteGraph.ALWAYS); + const isAllMuted = !hasAnyActiveNodes && group._nodes.every((n) => n.mode === LiteGraph.NEVER); + const isAllBypassed = !hasAnyActiveNodes && !isAllMuted && group._nodes.every((n) => n.mode === 4); + let newMode = LiteGraph.ALWAYS; + if (toggleMode === LiteGraph.NEVER) { + newMode = isAllMuted ? LiteGraph.ALWAYS : LiteGraph.NEVER; + } + else { + newMode = isAllBypassed ? LiteGraph.ALWAYS : 4; + } + for (const node of group._nodes) { + node.mode = newMode; + } + canvas.selected_group = null; + canvas.dragging_canvas = false; + } + } + })); + const drawGroups = LGraphCanvas.prototype.drawGroups; + LGraphCanvas.prototype.drawGroups = function (canvasEl, ctx) { + drawGroups.apply(this, [...arguments]); + if (!CONFIG_SERVICE.getFeatureValue("group_header_fast_toggle.enabled") || + !rgthree.lastAdjustedMouseEvent) { + return; + } + const graph = app.graph; + let groups; + if (CONFIG_SERVICE.getFeatureValue("group_header_fast_toggle.show") !== "always") { + const hoverGroup = graph.getGroupOnPos(rgthree.lastAdjustedMouseEvent.canvasX, rgthree.lastAdjustedMouseEvent.canvasY); + groups = hoverGroup ? [hoverGroup] : []; + } + else { + groups = graph._groups || []; + } + if (!groups.length) { + return; + } + const toggles = CONFIG_SERVICE.getFeatureValue("group_header_fast_toggle.toggles"); + ctx.save(); + for (const group of groups || []) { + let anyActive = false; + let allMuted = !!group._nodes.length; + let allBypassed = allMuted; + for (const node of group._nodes) { + anyActive = anyActive || node.mode === LiteGraph.ALWAYS; + allMuted = allMuted && node.mode === LiteGraph.NEVER; + allBypassed = allBypassed && node.mode === 4; + if (anyActive || (!allMuted && !allBypassed)) { + break; + } + } + for (let i = 0; i < toggles.length; i++) { + const toggle = toggles[i]; + const on = toggle === "bypass" ? allBypassed : allMuted; + const pos = group._pos; + const size = group._size; + ctx.fillStyle = ctx.strokeStyle = group.color || "#335"; + const x = pos[0] + size[0] - BTN_MARGIN[0] - BTN_SIZE - (BTN_SPACING + BTN_SIZE) * i; + const y = pos[1] + BTN_MARGIN[1]; + const midX = x + BTN_SIZE / 2; + const midY = y + BTN_SIZE / 2; + ctx.beginPath(); + ctx.lineJoin = "round"; + ctx.rect(x, y, BTN_SIZE, BTN_SIZE); + ctx.lineWidth = 2; + if (toggle === "mute") { + ctx.lineJoin = "round"; + ctx.lineCap = "round"; + if (on) { + ctx.stroke(new Path2D(` + ${eyeFrame(midX, midY)} + ${eyeLashes(midX, midY)} + `)); + } + else { + const radius = BTN_GRID * 1.5; + ctx.fill(new Path2D(` + ${eyeFrame(midX, midY)} + ${eyeFrame(midX, midY, -1)} + ${circlePath(midX, midY, radius)} + ${circlePath(midX + BTN_GRID / 2, midY - BTN_GRID / 2, BTN_GRID * 0.375)} + `), "evenodd"); + ctx.stroke(new Path2D(`${eyeFrame(midX, midY)} ${eyeFrame(midX, midY, -1)}`)); + ctx.globalAlpha = this.editor_alpha * 0.5; + ctx.stroke(new Path2D(`${eyeLashes(midX, midY)} ${eyeLashes(midX, midY, -1)}`)); + ctx.globalAlpha = this.editor_alpha; + } + } + else { + const lineChanges = on + ? 
`a ${BTN_GRID * 3}, ${BTN_GRID * 3} 0 1, 1 ${BTN_GRID * 3 * 2},0 + l ${BTN_GRID * 2.0} 0` + : `l ${BTN_GRID * 8} 0`; + ctx.stroke(new Path2D(` + M ${x} ${midY} + ${lineChanges} + M ${x + BTN_SIZE} ${midY} l -2 2 + M ${x + BTN_SIZE} ${midY} l -2 -2 + `)); + ctx.fill(new Path2D(`${circlePath(x + BTN_GRID * 3, midY, BTN_GRID * 1.8)}`)); + } + } + } + ctx.restore(); + }; + }, +}); +function eyeFrame(midX, midY, yFlip = 1) { + return ` + M ${midX - BTN_SIZE / 2} ${midY} + c ${BTN_GRID * 1.5} ${yFlip * BTN_GRID * 2.5}, ${BTN_GRID * (8 - 1.5)} ${yFlip * BTN_GRID * 2.5}, ${BTN_GRID * 8} 0 + `; +} +function eyeLashes(midX, midY, yFlip = 1) { + return ` + M ${midX - BTN_GRID * 3.46} ${midY + yFlip * BTN_GRID * 0.9} l -1.15 ${1.25 * yFlip} + M ${midX - BTN_GRID * 2.38} ${midY + yFlip * BTN_GRID * 1.6} l -0.90 ${1.5 * yFlip} + M ${midX - BTN_GRID * 1.15} ${midY + yFlip * BTN_GRID * 1.95} l -0.50 ${1.75 * yFlip} + M ${midX + BTN_GRID * 0.0} ${midY + yFlip * BTN_GRID * 2.0} l 0.00 ${2.0 * yFlip} + M ${midX + BTN_GRID * 1.15} ${midY + yFlip * BTN_GRID * 1.95} l 0.50 ${1.75 * yFlip} + M ${midX + BTN_GRID * 2.38} ${midY + yFlip * BTN_GRID * 1.6} l 0.90 ${1.5 * yFlip} + M ${midX + BTN_GRID * 3.46} ${midY + yFlip * BTN_GRID * 0.9} l 1.15 ${1.25 * yFlip} +`; +} +function circlePath(cx, cy, radius) { + return ` + M ${cx} ${cy} + m ${radius}, 0 + a ${radius},${radius} 0 1, 1 -${radius * 2},0 + a ${radius},${radius} 0 1, 1 ${radius * 2},0 + `; +} diff --git a/rgthree-comfy/web/comfyui/feature_import_individual_nodes.js b/rgthree-comfy/web/comfyui/feature_import_individual_nodes.js new file mode 100644 index 0000000000000000000000000000000000000000..4e0468e44bfc7da296e7fde4c3cd6ee2db497430 --- /dev/null +++ b/rgthree-comfy/web/comfyui/feature_import_individual_nodes.js @@ -0,0 +1,52 @@ +import { tryToGetWorkflowDataFromEvent } from "../../rgthree/common/utils_workflow.js"; +import { app } from "../../scripts/app.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +app.registerExtension({ + name: "rgthree.ImportIndividualNodes", + async beforeRegisterNodeDef(nodeType, nodeData) { + const onDragOver = nodeType.prototype.onDragOver; + nodeType.prototype.onDragOver = function (e) { + var _a; + let handled = (_a = onDragOver === null || onDragOver === void 0 ? void 0 : onDragOver.apply) === null || _a === void 0 ? void 0 : _a.call(onDragOver, this, [...arguments]); + if (handled != null) { + return handled; + } + return importIndividualNodesInnerOnDragOver(this, e); + }; + const onDragDrop = nodeType.prototype.onDragDrop; + nodeType.prototype.onDragDrop = async function (e) { + var _a; + const alreadyHandled = await ((_a = onDragDrop === null || onDragDrop === void 0 ? void 0 : onDragDrop.apply) === null || _a === void 0 ? void 0 : _a.call(onDragDrop, this, [...arguments])); + if (alreadyHandled) { + return alreadyHandled; + } + return importIndividualNodesInnerOnDragDrop(this, e); + }; + }, +}); +export function importIndividualNodesInnerOnDragOver(node, e) { + var _a; + return ((((_a = node.widgets) === null || _a === void 0 ? void 0 : _a.length) && !!CONFIG_SERVICE.getFeatureValue("import_individual_nodes.enabled")) || + false); +} +export async function importIndividualNodesInnerOnDragDrop(node, e) { + var _a, _b; + if (!((_a = node.widgets) === null || _a === void 0 ? 
void 0 : _a.length) || !CONFIG_SERVICE.getFeatureValue("import_individual_nodes.enabled")) { + return false; + } + let handled = false; + const { workflow, prompt } = await tryToGetWorkflowDataFromEvent(e); + if (!handled && workflow) { + const exact = (workflow.nodes || []).find((n) => n.id === node.id && n.type === node.type); + if (exact && + ((_b = exact.widgets_values) === null || _b === void 0 ? void 0 : _b.length) && + confirm("Found a node match from embedded workflow (same id & type) in this workflow. Would you like to set the widget values?")) { + node.configure({ widgets_values: [...((exact === null || exact === void 0 ? void 0 : exact.widgets_values) || [])] }); + handled = true; + } + } + if (!handled) { + handled = !confirm("No exact match found in workflow. Would you like to replace the whole workflow?"); + } + return handled; +} diff --git a/rgthree-comfy/web/comfyui/image_comparer.js b/rgthree-comfy/web/comfyui/image_comparer.js new file mode 100644 index 0000000000000000000000000000000000000000..832734cb714fc50619f6aa3a391aa1657a25dca5 --- /dev/null +++ b/rgthree-comfy/web/comfyui/image_comparer.js @@ -0,0 +1,363 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { RgthreeBaseWidget, } from "./utils_widgets.js"; +import { measureText } from "./utils_canvas.js"; +function imageDataToUrl(data) { + return api.apiURL(`/view?filename=${encodeURIComponent(data.filename)}&type=${data.type}&subfolder=${data.subfolder}${app.getPreviewFormatParam()}${app.getRandParam()}`); +} +export class RgthreeImageComparer extends RgthreeBaseServerNode { + constructor(title = RgthreeImageComparer.title) { + super(title); + this.imageIndex = 0; + this.imgs = []; + this.serialize_widgets = true; + this.isPointerDown = false; + this.isPointerOver = false; + this.pointerOverPos = [0, 0]; + this.canvasWidget = null; + this.properties["comparer_mode"] = "Slide"; + } + onExecuted(output) { + var _a; + (_a = super.onExecuted) === null || _a === void 0 ? void 0 : _a.call(this, output); + if ("images" in output) { + this.canvasWidget.value = { + images: (output.images || []).map((d, i) => { + return { + name: i === 0 ? "A" : "B", + selected: true, + url: imageDataToUrl(d), + }; + }), + }; + } + else { + output.a_images = output.a_images || []; + output.b_images = output.b_images || []; + const imagesToChoose = []; + const multiple = output.a_images.length + output.b_images.length > 2; + for (const [i, d] of output.a_images.entries()) { + imagesToChoose.push({ + name: output.a_images.length > 1 || multiple ? `A${i + 1}` : "A", + selected: i === 0, + url: imageDataToUrl(d), + }); + } + for (const [i, d] of output.b_images.entries()) { + imagesToChoose.push({ + name: output.b_images.length > 1 || multiple ? `B${i + 1}` : "B", + selected: i === 0, + url: imageDataToUrl(d), + }); + } + this.canvasWidget.value = { images: imagesToChoose }; + } + } + onSerialize(o) { + var _a; + super.onSerialize && super.onSerialize(o); + for (let [index, widget_value] of (o.widgets_values || []).entries()) { + if (((_a = this.widgets[index]) === null || _a === void 0 ? 
void 0 : _a.name) === "rgthree_comparer") { + o.widgets_values[index] = this.widgets[index].value.images.map((d) => { + d = { ...d }; + delete d.img; + return d; + }); + } + } + } + onNodeCreated() { + this.canvasWidget = this.addCustomWidget(new RgthreeImageComparerWidget("rgthree_comparer", this)); + this.setSize(this.computeSize()); + this.setDirtyCanvas(true, true); + } + setIsPointerDown(down = this.isPointerDown) { + const newIsDown = down && !!app.canvas.pointer_is_down; + if (this.isPointerDown !== newIsDown) { + this.isPointerDown = newIsDown; + this.setDirtyCanvas(true, false); + } + this.imageIndex = this.isPointerDown ? 1 : 0; + if (this.isPointerDown) { + requestAnimationFrame(() => { + this.setIsPointerDown(); + }); + } + } + onMouseDown(event, pos, graphCanvas) { + var _a; + (_a = super.onMouseDown) === null || _a === void 0 ? void 0 : _a.call(this, event, pos, graphCanvas); + this.setIsPointerDown(true); + } + onMouseEnter(event, pos, graphCanvas) { + var _a; + (_a = super.onMouseEnter) === null || _a === void 0 ? void 0 : _a.call(this, event, pos, graphCanvas); + this.setIsPointerDown(!!app.canvas.pointer_is_down); + this.isPointerOver = true; + } + onMouseLeave(event, pos, graphCanvas) { + var _a; + (_a = super.onMouseLeave) === null || _a === void 0 ? void 0 : _a.call(this, event, pos, graphCanvas); + this.setIsPointerDown(false); + this.isPointerOver = false; + } + onMouseMove(event, pos, graphCanvas) { + var _a; + (_a = super.onMouseMove) === null || _a === void 0 ? void 0 : _a.call(this, event, pos, graphCanvas); + this.pointerOverPos = [...pos]; + this.imageIndex = this.pointerOverPos[0] > this.size[0] / 2 ? 1 : 0; + } + getHelp() { + return ` +

+ The ${this.type.replace("(rgthree)", "")} node compares two images on top of each other.
+
+ • Notes
+   • The right-click menu may show image options (Open Image, Save Image, etc.) which
+     correspond to the first image (image_a) if clicked on the left half of the node, or to
+     the second image if clicked on the right half of the node.
+
+ • Inputs
+   • image_a - Optional. The first image to compare.
+   • image_b - Optional. The second image to compare.
+   • Note: image_a and image_b work best when a single image is provided to each. However, if
+     either is a batch, you can choose which item from each batch is compared. If either
+     image_a or image_b is not provided, the node will choose the first two images from the
+     provided input if it is a batch, and otherwise show only the single image (just as
+     Preview Image would).
+
+ • Properties. You can change the following properties (by right-clicking on the node and
+   selecting "Properties" or "Properties Panel" from the menu):
+   • comparer_mode - Choose between "Slide" and "Click". Defaults to "Slide".
`; + } + static setUp(comfyClass, nodeData) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, RgthreeImageComparer); + } + static onRegisteredForOverride(comfyClass) { + addConnectionLayoutSupport(RgthreeImageComparer, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + setTimeout(() => { + RgthreeImageComparer.category = comfyClass.category; + }); + } +} +RgthreeImageComparer.title = NodeTypesString.IMAGE_COMPARER; +RgthreeImageComparer.type = NodeTypesString.IMAGE_COMPARER; +RgthreeImageComparer.comfyClass = NodeTypesString.IMAGE_COMPARER; +RgthreeImageComparer["@comparer_mode"] = { + type: "combo", + values: ["Slide", "Click"], +}; +class RgthreeImageComparerWidget extends RgthreeBaseWidget { + constructor(name, node) { + super(name); + this.hitAreas = {}; + this.selected = []; + this._value = { images: [] }; + this.node = node; + } + set value(v) { + let cleanedVal; + if (Array.isArray(v)) { + cleanedVal = v.map((d, i) => { + if (!d || typeof d === "string") { + d = { url: d, name: i == 0 ? "A" : "B", selected: true }; + } + return d; + }); + } + else { + cleanedVal = v.images || []; + } + if (cleanedVal.length > 2) { + const hasAAndB = cleanedVal.some((i) => i.name.startsWith("A")) && + cleanedVal.some((i) => i.name.startsWith("B")); + if (!hasAAndB) { + cleanedVal = [cleanedVal[0], cleanedVal[1]]; + } + } + let selected = cleanedVal.filter((d) => d.selected); + if (!selected.length && cleanedVal.length) { + cleanedVal[0].selected = true; + } + selected = cleanedVal.filter((d) => d.selected); + if (selected.length === 1 && cleanedVal.length > 1) { + cleanedVal.find((d) => !d.selected).selected = true; + } + this._value.images = cleanedVal; + selected = cleanedVal.filter((d) => d.selected); + this.setSelected(selected); + } + get value() { + return this._value; + } + setSelected(selected) { + this._value.images.forEach((d) => (d.selected = false)); + this.node.imgs.length = 0; + for (const sel of selected) { + if (!sel.img) { + sel.img = new Image(); + sel.img.src = sel.url; + this.node.imgs.push(sel.img); + } + sel.selected = true; + } + this.selected = selected; + } + draw(ctx, node, width, y) { + var _a; + this.hitAreas = {}; + if (this.value.images.length > 2) { + ctx.textAlign = "left"; + ctx.textBaseline = "top"; + ctx.font = `14px Arial`; + const drawData = []; + const spacing = 5; + let x = 0; + for (const img of this.value.images) { + const width = measureText(ctx, img.name); + drawData.push({ + img, + text: img.name, + x, + width: measureText(ctx, img.name), + }); + x += width + spacing; + } + x = (node.size[0] - (x - spacing)) / 2; + for (const d of drawData) { + ctx.fillStyle = d.img.selected ? "rgba(180, 180, 180, 1)" : "rgba(180, 180, 180, 0.5)"; + ctx.fillText(d.text, x, y); + this.hitAreas[d.text] = { + bounds: [x, y, d.width, 14], + data: d.img, + onDown: this.onSelectionDown, + }; + x += d.width + spacing; + } + y += 20; + } + if (((_a = node.properties) === null || _a === void 0 ? void 0 : _a["comparer_mode"]) === "Click") { + this.drawImage(ctx, this.selected[this.node.isPointerDown ? 1 : 0], y); + } + else { + this.drawImage(ctx, this.selected[0], y); + if (node.isPointerOver) { + this.drawImage(ctx, this.selected[1], y, this.node.pointerOverPos[0]); + } + } + } + onSelectionDown(event, pos, node, bounds) { + const selected = [...this.selected]; + if (bounds === null || bounds === void 0 ? void 0 : bounds.data.name.startsWith("A")) { + selected[0] = bounds.data; + } + else if (bounds === null || bounds === void 0 ? 
void 0 : bounds.data.name.startsWith("B")) { + selected[1] = bounds.data; + } + this.setSelected(selected); + } + drawImage(ctx, image, y, cropX) { + var _a, _b; + if (!((_a = image === null || image === void 0 ? void 0 : image.img) === null || _a === void 0 ? void 0 : _a.naturalWidth) || !((_b = image === null || image === void 0 ? void 0 : image.img) === null || _b === void 0 ? void 0 : _b.naturalHeight)) { + return; + } + let [nodeWidth, nodeHeight] = this.node.size; + const imageAspect = (image === null || image === void 0 ? void 0 : image.img.naturalWidth) / (image === null || image === void 0 ? void 0 : image.img.naturalHeight); + let height = nodeHeight - y; + const widgetAspect = nodeWidth / height; + let targetWidth, targetHeight; + let offsetX = 0; + if (imageAspect > widgetAspect) { + targetWidth = nodeWidth; + targetHeight = nodeWidth / imageAspect; + } + else { + targetHeight = height; + targetWidth = height * imageAspect; + offsetX = (nodeWidth - targetWidth) / 2; + } + const widthMultiplier = (image === null || image === void 0 ? void 0 : image.img.naturalWidth) / targetWidth; + const sourceX = 0; + const sourceY = 0; + const sourceWidth = cropX != null ? (cropX - offsetX) * widthMultiplier : image === null || image === void 0 ? void 0 : image.img.naturalWidth; + const sourceHeight = image === null || image === void 0 ? void 0 : image.img.naturalHeight; + const destX = (nodeWidth - targetWidth) / 2; + const destY = y + (height - targetHeight) / 2; + const destWidth = cropX != null ? cropX - offsetX : targetWidth; + const destHeight = targetHeight; + ctx.save(); + ctx.beginPath(); + let globalCompositeOperation = ctx.globalCompositeOperation; + if (cropX) { + ctx.rect(destX, destY, destWidth, destHeight); + ctx.clip(); + } + ctx.drawImage(image === null || image === void 0 ? 
void 0 : image.img, sourceX, sourceY, sourceWidth, sourceHeight, destX, destY, destWidth, destHeight); + if (cropX != null && cropX >= (nodeWidth - targetWidth) / 2 && cropX <= targetWidth + offsetX) { + ctx.beginPath(); + ctx.moveTo(cropX, destY); + ctx.lineTo(cropX, destY + destHeight); + ctx.globalCompositeOperation = "difference"; + ctx.strokeStyle = "rgba(255,255,255, 1)"; + ctx.stroke(); + } + ctx.globalCompositeOperation = globalCompositeOperation; + ctx.restore(); + } + computeSize(width) { + return [width, 20]; + } + serializeValue(serializedNode, widgetIndex) { + const v = []; + for (const data of this._value.images) { + const d = { ...data }; + delete d.img; + v.push(d); + } + return { images: v }; + } +} +app.registerExtension({ + name: "rgthree.ImageComparer", + async beforeRegisterNodeDef(nodeType, nodeData) { + if (nodeData.name === RgthreeImageComparer.type) { + RgthreeImageComparer.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/image_inset_crop.js b/rgthree-comfy/web/comfyui/image_inset_crop.js new file mode 100644 index 0000000000000000000000000000000000000000..17217d0ca3b06229c7e59c559eed38af880c1ae3 --- /dev/null +++ b/rgthree-comfy/web/comfyui/image_inset_crop.js @@ -0,0 +1,59 @@ +import { app } from "../../scripts/app.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; +class ImageInsetCrop extends RgthreeBaseServerNode { + constructor(title = ImageInsetCrop.title) { + super(title); + } + onAdded(graph) { + const measurementWidget = this.widgets[0]; + let callback = measurementWidget.callback; + measurementWidget.callback = (...args) => { + this.setWidgetStep(); + callback && callback.apply(measurementWidget, [...args]); + }; + this.setWidgetStep(); + } + configure(info) { + super.configure(info); + this.setWidgetStep(); + } + setWidgetStep() { + const measurementWidget = this.widgets[0]; + for (let i = 1; i <= 4; i++) { + if (measurementWidget.value === "Pixels") { + this.widgets[i].options.step = 80; + this.widgets[i].options.max = ImageInsetCrop.maxResolution; + } + else { + this.widgets[i].options.step = 10; + this.widgets[i].options.max = 99; + } + } + } + async handleAction(action) { + if (action === "Reset Crop") { + for (const widget of this.widgets) { + if (["left", "right", "top", "bottom"].includes(widget.name)) { + widget.value = 0; + } + } + } + } + static setUp(comfyClass, nodeData) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, ImageInsetCrop); + } +} +ImageInsetCrop.title = NodeTypesString.IMAGE_INSET_CROP; +ImageInsetCrop.type = NodeTypesString.IMAGE_INSET_CROP; +ImageInsetCrop.comfyClass = NodeTypesString.IMAGE_INSET_CROP; +ImageInsetCrop.exposedActions = ["Reset Crop"]; +ImageInsetCrop.maxResolution = 8192; +app.registerExtension({ + name: "rgthree.ImageInsetCrop", + async beforeRegisterNodeDef(nodeType, nodeData, _app) { + if (nodeData.name === NodeTypesString.IMAGE_INSET_CROP) { + ImageInsetCrop.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/label.js b/rgthree-comfy/web/comfyui/label.js new file mode 100644 index 0000000000000000000000000000000000000000..14e1faf3f9e79a9545bc60ae8815a7e612846380 --- /dev/null +++ b/rgthree-comfy/web/comfyui/label.js @@ -0,0 +1,152 @@ +import { app } from "../../scripts/app.js"; +import { RgthreeBaseVirtualNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; +import { rgthree } from "./rgthree.js"; +export class Label extends 
RgthreeBaseVirtualNode { + constructor(title = Label.title) { + super(title); + this.comfyClass = NodeTypesString.LABEL; + this.resizable = false; + this.properties["fontSize"] = 12; + this.properties["fontFamily"] = "Arial"; + this.properties["fontColor"] = "#ffffff"; + this.properties["textAlign"] = "left"; + this.properties["backgroundColor"] = "transparent"; + this.properties["padding"] = 0; + this.properties["borderRadius"] = 0; + this.color = "#fff0"; + this.bgcolor = "#fff0"; + this.onConstructed(); + } + draw(ctx) { + var _a, _b; + this.flags = this.flags || {}; + this.flags.allow_interaction = !this.flags.pinned; + ctx.save(); + this.color = "#fff0"; + this.bgcolor = "#fff0"; + const fontColor = this.properties["fontColor"] || "#ffffff"; + const backgroundColor = this.properties["backgroundColor"] || ""; + ctx.font = `${Math.max(this.properties["fontSize"] || 0, 1)}px ${(_a = this.properties["fontFamily"]) !== null && _a !== void 0 ? _a : "Arial"}`; + const padding = (_b = Number(this.properties["padding"])) !== null && _b !== void 0 ? _b : 0; + const lines = this.title.replace(/\n*$/, "").split("\n"); + const maxWidth = Math.max(...lines.map((s) => ctx.measureText(s).width)); + this.size[0] = maxWidth + padding * 2; + this.size[1] = this.properties["fontSize"] * lines.length + padding * 2; + if (backgroundColor) { + ctx.beginPath(); + const borderRadius = Number(this.properties["borderRadius"]) || 0; + ctx.roundRect(0, 0, this.size[0], this.size[1], [borderRadius]); + ctx.fillStyle = backgroundColor; + ctx.fill(); + } + ctx.textAlign = "left"; + let textX = padding; + if (this.properties["textAlign"] === "center") { + ctx.textAlign = "center"; + textX = this.size[0] / 2; + } + else if (this.properties["textAlign"] === "right") { + ctx.textAlign = "right"; + textX = this.size[0] - padding; + } + ctx.textBaseline = "top"; + ctx.fillStyle = fontColor; + let currentY = padding; + for (let i = 0; i < lines.length; i++) { + ctx.fillText(lines[i] || " ", textX, currentY); + currentY += this.properties["fontSize"]; + } + ctx.restore(); + } + onDblClick(event, pos, canvas) { + LGraphCanvas.active_canvas.showShowNodePanel(this); + } + onShowCustomPanelInfo(panel) { + var _a, _b; + (_a = panel.querySelector('div.property[data-property="Mode"]')) === null || _a === void 0 ? void 0 : _a.remove(); + (_b = panel.querySelector('div.property[data-property="Color"]')) === null || _b === void 0 ? void 0 : _b.remove(); + } + inResizeCorner(x, y) { + return this.resizable; + } + getHelp() { + return ` +

+ The rgthree-comfy ${this.type.replace("(rgthree)", "")} node allows you to add a floating
+ label to your workflow.
+
+ The text shown is the "Title" of the node, and you can adjust the font size, font family,
+ font color, and text alignment, as well as a background color, padding, and background border
+ radius, from the node's properties. You can double-click the node to open the properties
+ panel.
+
+ • Pro tip #1: You can add multiline text from the properties panel
+   (because ComfyUI lets you shift + enter there, only).
+
+ • Pro tip #2: You can use ComfyUI's native "pin" option in the right-click menu to make the
+   label stick to the workflow and let clicks "go through" it. You can right-click at any time
+   to unpin.
+
+ • Pro tip #3: Color values are hexadecimal strings, like "#FFFFFF" for white or "#660000" for
+   dark red. You can supply a 7th & 8th hex digit (or a 4th when using shorthand) to create a
+   translucent color. For instance, "#FFFFFF88" is semi-transparent white.
`; + } +} +Label.type = NodeTypesString.LABEL; +Label.title = NodeTypesString.LABEL; +Label.title_mode = LiteGraph.NO_TITLE; +Label.collapsable = false; +Label["@fontSize"] = { type: "number" }; +Label["@fontFamily"] = { type: "string" }; +Label["@fontColor"] = { type: "string" }; +Label["@textAlign"] = { type: "combo", values: ["left", "center", "right"] }; +Label["@backgroundColor"] = { type: "string" }; +Label["@padding"] = { type: "number" }; +Label["@borderRadius"] = { type: "number" }; +const oldDrawNode = LGraphCanvas.prototype.drawNode; +LGraphCanvas.prototype.drawNode = function (node, ctx) { + if (node.constructor === Label) { + node.bgcolor = "transparent"; + node.color = "transparent"; + const v = oldDrawNode.apply(this, arguments); + node.draw(ctx); + return v; + } + const v = oldDrawNode.apply(this, arguments); + return v; +}; +const oldGetNodeOnPos = LGraph.prototype.getNodeOnPos; +LGraph.prototype.getNodeOnPos = function (x, y, nodes_list, margin) { + var _a, _b; + if (nodes_list && + rgthree.processingMouseDown && + ((_a = rgthree.lastAdjustedMouseEvent) === null || _a === void 0 ? void 0 : _a.type.includes("down")) && + ((_b = rgthree.lastAdjustedMouseEvent) === null || _b === void 0 ? void 0 : _b.which) === 1) { + let isDoubleClick = LiteGraph.getTime() - LGraphCanvas.active_canvas.last_mouseclick < 300; + if (!isDoubleClick) { + nodes_list = [...nodes_list].filter((n) => { var _a; return !(n instanceof Label) || !((_a = n.flags) === null || _a === void 0 ? void 0 : _a.pinned); }); + } + } + return oldGetNodeOnPos.apply(this, [x, y, nodes_list, margin]); +}; +app.registerExtension({ + name: "rgthree.Label", + registerCustomNodes() { + Label.setUp(); + }, +}); diff --git a/rgthree-comfy/web/comfyui/menu_auto_nest.js b/rgthree-comfy/web/comfyui/menu_auto_nest.js new file mode 100644 index 0000000000000000000000000000000000000000..da251cee4e9df7d09df9ea84e0d771dd323d5d7e --- /dev/null +++ b/rgthree-comfy/web/comfyui/menu_auto_nest.js @@ -0,0 +1,110 @@ +import { app } from "../../scripts/app.js"; +import { rgthree } from "./rgthree.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +const SPECIAL_ENTRIES = [/^(CHOOSE|NONE|DISABLE|OPEN)(\s|$)/i, /^\p{Extended_Pictographic}/gu]; +app.registerExtension({ + name: "rgthree.ContextMenuAutoNest", + async setup() { + const logger = rgthree.newLogSession("[ContextMenuAutoNest]"); + const existingContextMenu = LiteGraph.ContextMenu; + LiteGraph.ContextMenu = function (values, options) { + var _a, _b, _c, _d, _e, _f; + const threshold = CONFIG_SERVICE.getConfigValue("features.menu_auto_nest.threshold", 20); + const enabled = CONFIG_SERVICE.getConfigValue("features.menu_auto_nest.subdirs", false); + let incompatible = !enabled || !!((_a = options === null || options === void 0 ? void 0 : options.extra) === null || _a === void 0 ? void 0 : _a.rgthree_doNotNest); + if (!incompatible) { + if (values.length <= threshold) { + incompatible = `Skipping context menu auto nesting b/c threshold is not met (${threshold})`; + } + if (!((_b = options.parentMenu) === null || _b === void 0 ? void 0 : _b.options.rgthree_originalCallback)) { + if (!(options === null || options === void 0 ? 
void 0 : options.callback)) { + incompatible = `Skipping context menu auto nesting b/c a callback was expected.`; + } + else if (values.some((i) => typeof i !== "string")) { + incompatible = `Skipping context menu auto nesting b/c not all values were strings.`; + } + } + } + if (incompatible) { + if (enabled) { + const [n, v] = logger.infoParts("Skipping context menu auto nesting for incompatible menu."); + (_c = console[n]) === null || _c === void 0 ? void 0 : _c.call(console, ...v); + } + return existingContextMenu.apply(this, [...arguments]); + } + const folders = {}; + const specialOps = []; + const folderless = []; + for (const value of values) { + if (!value) { + folderless.push(value); + continue; + } + const newValue = typeof value === "string" ? { content: value } : Object.assign({}, value); + newValue.rgthree_originalValue = value.rgthree_originalValue || value; + const valueContent = newValue.content || ''; + const splitBy = valueContent.indexOf("/") > -1 ? "/" : "\\"; + const valueSplit = valueContent.split(splitBy); + if (valueSplit.length > 1) { + const key = valueSplit.shift(); + newValue.content = valueSplit.join(splitBy); + folders[key] = folders[key] || []; + folders[key].push(newValue); + } + else if (SPECIAL_ENTRIES.some((r) => r.test(valueContent))) { + specialOps.push(newValue); + } + else { + folderless.push(newValue); + } + } + const foldersCount = Object.values(folders).length; + if (foldersCount > 0) { + options.rgthree_originalCallback = + options.rgthree_originalCallback || + ((_d = options.parentMenu) === null || _d === void 0 ? void 0 : _d.options.rgthree_originalCallback) || + options.callback; + const oldCallback = options.rgthree_originalCallback; + options.callback = undefined; + const newCallback = (item, options, event, parentMenu, node) => { + oldCallback === null || oldCallback === void 0 ? void 0 : oldCallback(item === null || item === void 0 ? void 0 : item.rgthree_originalValue, options, event, undefined, node); + }; + const [n, v] = logger.infoParts(`Nested folders found (${foldersCount}).`); + (_e = console[n]) === null || _e === void 0 ? void 0 : _e.call(console, ...v); + const newValues = []; + for (const [folderName, folderValues] of Object.entries(folders)) { + newValues.push({ + content: `📁 ${folderName}`, + has_submenu: true, + callback: () => { + }, + submenu: { + options: folderValues.map((value) => { + value.callback = newCallback; + return value; + }), + }, + }); + } + values = [].concat(specialOps.map((f) => { + if (typeof f === "string") { + f = { content: f }; + } + f.callback = newCallback; + return f; + }), newValues, folderless.map((f) => { + if (typeof f === "string") { + f = { content: f }; + } + f.callback = newCallback; + return f; + })); + } + if (options.scale == null) { + options.scale = Math.max(((_f = app.canvas.ds) === null || _f === void 0 ? 
void 0 : _f.scale) || 1, 1); + } + return existingContextMenu.call(this, values, options); + }; + LiteGraph.ContextMenu.prototype = existingContextMenu.prototype; + }, +}); diff --git a/rgthree-comfy/web/comfyui/menu_copy_image.js b/rgthree-comfy/web/comfyui/menu_copy_image.js new file mode 100644 index 0000000000000000000000000000000000000000..4a0be8ffc7a9d3c4b222266ada123294299cec45 --- /dev/null +++ b/rgthree-comfy/web/comfyui/menu_copy_image.js @@ -0,0 +1,61 @@ +import { app } from "../../scripts/app.js"; +const clipboardSupportedPromise = new Promise(async (resolve) => { + try { + const result = await navigator.permissions.query({ name: "clipboard-write" }); + resolve(result.state === "granted"); + return; + } + catch (e) { + try { + if (!navigator.clipboard.write) { + throw new Error(); + } + new ClipboardItem({ "image/png": new Blob([], { type: "image/png" }) }); + resolve(true); + return; + } + catch (e) { + resolve(false); + } + } +}); +app.registerExtension({ + name: "rgthree.CopyImageToClipboard", + async beforeRegisterNodeDef(nodeType, nodeData) { + if (nodeData.name.toLowerCase().includes("image")) { + if (await clipboardSupportedPromise) { + const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; + nodeType.prototype.getExtraMenuOptions = function (canvas, options) { + var _a; + getExtraMenuOptions ? getExtraMenuOptions.apply(this, arguments) : undefined; + if ((_a = this.imgs) === null || _a === void 0 ? void 0 : _a.length) { + let img = this.imgs[this.imageIndex || 0] || this.imgs[this.overIndex || 0] || this.imgs[0]; + const foundIdx = options.findIndex((option) => { var _a; return (_a = option === null || option === void 0 ? void 0 : option.content) === null || _a === void 0 ? void 0 : _a.includes("Copy Image"); }); + if (img && foundIdx === -1) { + const menuItem = { + content: "Copy Image (rgthree)", + callback: () => { + const canvas = document.createElement("canvas"); + const ctx = canvas.getContext("2d"); + canvas.width = img.naturalWidth; + canvas.height = img.naturalHeight; + ctx.drawImage(img, 0, 0, img.naturalWidth, img.naturalHeight); + canvas.toBlob((blob) => { + navigator.clipboard.write([new ClipboardItem({ "image/png": blob })]); + }); + }, + }; + let idx = options.findIndex((option) => { var _a; return (_a = option === null || option === void 0 ? void 0 : option.content) === null || _a === void 0 ? void 0 : _a.includes("Open Image"); }) + 1; + if (idx != null) { + options.splice(idx, 0, menuItem); + } + else { + options.unshift(menuItem); + } + } + } + }; + } + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/menu_queue_node.js b/rgthree-comfy/web/comfyui/menu_queue_node.js new file mode 100644 index 0000000000000000000000000000000000000000..2f28fb3ab7f838398f3a6befff6d6ca08a0604f1 --- /dev/null +++ b/rgthree-comfy/web/comfyui/menu_queue_node.js @@ -0,0 +1,69 @@ +import { app } from "../../scripts/app.js"; +import { rgthree } from "./rgthree.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +function getOutputNodes(nodes) { + return ((nodes === null || nodes === void 0 ? void 0 : nodes.filter((n) => { + var _a; + return (n.mode != LiteGraph.NEVER && + ((_a = n.constructor.nodeData) === null || _a === void 0 ? 
void 0 : _a.output_node)); + })) || []); +} +function showQueueNodesMenuIfOutputNodesAreSelected(existingOptions) { + if (CONFIG_SERVICE.getConfigValue("features.menu_queue_selected_nodes") === false) { + return; + } + const outputNodes = getOutputNodes(Object.values(app.canvas.selected_nodes)); + const menuItem = { + content: `Queue Selected Output Nodes (rgthree)  `, + className: "rgthree-contextmenu-item", + callback: () => { + rgthree.queueOutputNodes(outputNodes.map((n) => n.id)); + }, + disabled: !outputNodes.length, + }; + let idx = existingOptions.findIndex((o) => (o === null || o === void 0 ? void 0 : o.content) === "Outputs") + 1; + idx = idx || existingOptions.findIndex((o) => (o === null || o === void 0 ? void 0 : o.content) === "Align") + 1; + idx = idx || 3; + existingOptions.splice(idx, 0, menuItem); +} +function showQueueGroupNodesMenuIfGroupIsSelected(existingOptions) { + if (CONFIG_SERVICE.getConfigValue("features.menu_queue_selected_nodes") === false) { + return; + } + const group = rgthree.lastAdjustedMouseEvent && + app.graph.getGroupOnPos(rgthree.lastAdjustedMouseEvent.canvasX, rgthree.lastAdjustedMouseEvent.canvasY); + const outputNodes = group && getOutputNodes(group._nodes); + const menuItem = { + content: `Queue Group Output Nodes (rgthree)  `, + className: "rgthree-contextmenu-item", + callback: () => { + outputNodes && rgthree.queueOutputNodes(outputNodes.map((n) => n.id)); + }, + disabled: !(outputNodes === null || outputNodes === void 0 ? void 0 : outputNodes.length), + }; + let idx = existingOptions.findIndex((o) => { var _a; return (_a = o === null || o === void 0 ? void 0 : o.content) === null || _a === void 0 ? void 0 : _a.startsWith("Queue Selected "); }) + 1; + idx = idx || existingOptions.findIndex((o) => (o === null || o === void 0 ? void 0 : o.content) === "Outputs") + 1; + idx = idx || existingOptions.findIndex((o) => (o === null || o === void 0 ? void 0 : o.content) === "Align") + 1; + idx = idx || 3; + existingOptions.splice(idx, 0, menuItem); +} +app.registerExtension({ + name: "rgthree.QueueNode", + async beforeRegisterNodeDef(nodeType, nodeData) { + const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; + nodeType.prototype.getExtraMenuOptions = function (canvas, options) { + getExtraMenuOptions ? 
getExtraMenuOptions.apply(this, arguments) : undefined; + showQueueNodesMenuIfOutputNodesAreSelected(options); + showQueueGroupNodesMenuIfGroupIsSelected(options); + }; + }, + async setup() { + const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; + LGraphCanvas.prototype.getCanvasMenuOptions = function (...args) { + const options = getCanvasMenuOptions.apply(this, [...args]); + showQueueNodesMenuIfOutputNodesAreSelected(options); + showQueueGroupNodesMenuIfGroupIsSelected(options); + return options; + }; + }, +}); diff --git a/rgthree-comfy/web/comfyui/muter.js b/rgthree-comfy/web/comfyui/muter.js new file mode 100644 index 0000000000000000000000000000000000000000..f18210f61bfb2c1ba30ae871c61411b1435744ef --- /dev/null +++ b/rgthree-comfy/web/comfyui/muter.js @@ -0,0 +1,45 @@ +import { app } from "../../scripts/app.js"; +import { BaseNodeModeChanger } from "./base_node_mode_changer.js"; +import { NodeTypesString } from "./constants.js"; +const MODE_MUTE = 2; +const MODE_ALWAYS = 0; +class MuterNode extends BaseNodeModeChanger { + constructor(title = MuterNode.title) { + super(title); + this.comfyClass = NodeTypesString.FAST_MUTER; + this.modeOn = MODE_ALWAYS; + this.modeOff = MODE_MUTE; + this.onConstructed(); + } + async handleAction(action) { + if (action === "Mute all") { + for (const widget of this.widgets) { + this.forceWidgetOff(widget, true); + } + } + else if (action === "Enable all") { + for (const widget of this.widgets) { + this.forceWidgetOn(widget, true); + } + } + else if (action === "Toggle all") { + for (const widget of this.widgets) { + this.forceWidgetToggle(widget, true); + } + } + } +} +MuterNode.exposedActions = ["Mute all", "Enable all", "Toggle all"]; +MuterNode.type = NodeTypesString.FAST_MUTER; +MuterNode.title = NodeTypesString.FAST_MUTER; +app.registerExtension({ + name: "rgthree.Muter", + registerCustomNodes() { + MuterNode.setUp(); + }, + loadedGraphNode(node) { + if (node.type == MuterNode.title) { + node._tempWidth = node.size[0]; + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/node_collector.js b/rgthree-comfy/web/comfyui/node_collector.js new file mode 100644 index 0000000000000000000000000000000000000000..c4f6e0fef4826c8756f38a156496604ffb883a43 --- /dev/null +++ b/rgthree-comfy/web/comfyui/node_collector.js @@ -0,0 +1,118 @@ +import { app } from "../../scripts/app.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { wait } from "../../rgthree/common/shared_utils.js"; +import { ComfyWidgets } from "../../scripts/widgets.js"; +import { BaseCollectorNode } from "./base_node_collector.js"; +import { NodeTypesString } from "./constants.js"; +class CollectorNode extends BaseCollectorNode { + constructor(title = CollectorNode.title) { + super(title); + this.comfyClass = NodeTypesString.NODE_COLLECTOR; + this.onConstructed(); + } + onConstructed() { + this.addOutput("Output", "*"); + return super.onConstructed(); + } + configure(info) { + var _a; + if ((_a = info.outputs) === null || _a === void 0 ? void 0 : _a.length) { + info.outputs.length = 1; + } + super.configure(info); + } +} +CollectorNode.type = NodeTypesString.NODE_COLLECTOR; +CollectorNode.title = NodeTypesString.NODE_COLLECTOR; +class CombinerNode extends CollectorNode { + constructor(title = CombinerNode.title) { + super(title); + const note = ComfyWidgets["STRING"](this, "last_seed", ["STRING", { multiline: true }], app).widget; + note.inputEl.value = + 'The Node Combiner has been renamed to Node Collector. 
You can right-click and select "Update to Node Collector" to attempt to automatically update.'; + note.inputEl.readOnly = true; + note.inputEl.style.backgroundColor = "#332222"; + note.inputEl.style.fontWeight = "bold"; + note.inputEl.style.fontStyle = "italic"; + note.inputEl.style.opacity = "0.8"; + this.getExtraMenuOptions = (_, options) => { + options.splice(options.length - 1, 0, { + content: "‼️ Update to Node Collector", + callback: (_value, _options, _event, _parentMenu, _node) => { + updateCombinerToCollector(this); + }, + }); + }; + } + configure(info) { + super.configure(info); + if (this.title != CombinerNode.title && !this.title.startsWith("‼️")) { + this.title = "‼️ " + this.title; + } + } +} +CombinerNode.legacyType = "Node Combiner (rgthree)"; +CombinerNode.title = "‼️ Node Combiner [DEPRECATED]"; +async function updateCombinerToCollector(node) { + if (node.type === CombinerNode.legacyType) { + const newNode = new CollectorNode(); + if (node.title != CombinerNode.title) { + newNode.title = node.title.replace("‼️ ", ""); + } + newNode.pos = [...node.pos]; + newNode.size = [...node.size]; + newNode.properties = { ...node.properties }; + const links = []; + for (const [index, output] of node.outputs.entries()) { + for (const linkId of output.links || []) { + const link = app.graph.links[linkId]; + if (!link) + continue; + const targetNode = app.graph.getNodeById(link.target_id); + links.push({ node: newNode, slot: index, targetNode, targetSlot: link.target_slot }); + } + } + for (const [index, input] of node.inputs.entries()) { + const linkId = input.link; + if (linkId) { + const link = app.graph.links[linkId]; + const originNode = app.graph.getNodeById(link.origin_id); + links.push({ + node: originNode, + slot: link.origin_slot, + targetNode: newNode, + targetSlot: index, + }); + } + } + app.graph.add(newNode); + await wait(); + for (const link of links) { + link.node.connect(link.slot, link.targetNode, link.targetSlot); + } + await wait(); + app.graph.remove(node); + } +} +app.registerExtension({ + name: "rgthree.NodeCollector", + registerCustomNodes() { + addConnectionLayoutSupport(CollectorNode, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + LiteGraph.registerNodeType(CollectorNode.title, CollectorNode); + CollectorNode.category = CollectorNode._category; + }, +}); +app.registerExtension({ + name: "rgthree.NodeCombiner", + registerCustomNodes() { + addConnectionLayoutSupport(CombinerNode, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + LiteGraph.registerNodeType(CombinerNode.legacyType, CombinerNode); + CombinerNode.category = CombinerNode._category; + }, +}); diff --git a/rgthree-comfy/web/comfyui/node_mode_relay.js b/rgthree-comfy/web/comfyui/node_mode_relay.js new file mode 100644 index 0000000000000000000000000000000000000000..8343e6ca6ac40f7589796d50039d98b1e2b46e58 --- /dev/null +++ b/rgthree-comfy/web/comfyui/node_mode_relay.js @@ -0,0 +1,212 @@ +import { app } from "../../scripts/app.js"; +import { PassThroughFollowing, addConnectionLayoutSupport, getConnectedInputNodesAndFilterPassThroughs, getConnectedOutputNodesAndFilterPassThroughs, } from "./utils.js"; +import { wait } from "../../rgthree/common/shared_utils.js"; +import { BaseCollectorNode } from "./base_node_collector.js"; +import { NodeTypesString, stripRgthree } from "./constants.js"; +import { fitString } from "./utils_canvas.js"; +import { rgthree } from "./rgthree.js"; +const MODE_ALWAYS = 0; +const MODE_MUTE = 2; +const MODE_BYPASS = 4; +const MODE_REPEATS = [MODE_MUTE, MODE_BYPASS]; 
+const MODE_NOTHING = -99; +const MODE_TO_OPTION = new Map([ + [MODE_ALWAYS, "ACTIVE"], + [MODE_MUTE, "MUTE"], + [MODE_BYPASS, "BYPASS"], + [MODE_NOTHING, "NOTHING"], +]); +const OPTION_TO_MODE = new Map([ + ["ACTIVE", MODE_ALWAYS], + ["MUTE", MODE_MUTE], + ["BYPASS", MODE_BYPASS], + ["NOTHING", MODE_NOTHING], +]); +const MODE_TO_PROPERTY = new Map([ + [MODE_MUTE, "on_muted_inputs"], + [MODE_BYPASS, "on_bypassed_inputs"], + [MODE_ALWAYS, "on_any_active_inputs"], +]); +const logger = rgthree.newLogSession("[NodeModeRelay]"); +class NodeModeRelay extends BaseCollectorNode { + constructor(title) { + super(title); + this.inputsPassThroughFollowing = PassThroughFollowing.ALL; + this.comfyClass = NodeTypesString.NODE_MODE_RELAY; + this.properties["on_muted_inputs"] = "MUTE"; + this.properties["on_bypassed_inputs"] = "BYPASS"; + this.properties["on_any_active_inputs"] = "ACTIVE"; + this.onConstructed(); + } + onConstructed() { + this.addOutput("REPEATER", "_NODE_REPEATER_", { + color_on: "#Fc0", + color_off: "#a80", + shape: LiteGraph.ARROW_SHAPE, + }); + setTimeout(() => { + this.stabilize(); + }, 500); + return super.onConstructed(); + } + onModeChange(from, to) { + var _a; + super.onModeChange(from, to); + if (this.inputs.length <= 1 && !this.isInputConnected(0) && this.isAnyOutputConnected()) { + const [n, v] = logger.infoParts(`Mode change without any inputs; relaying our mode.`); + (_a = console[n]) === null || _a === void 0 ? void 0 : _a.call(console, ...v); + this.dispatchModeToRepeater(this.mode); + } + } + configure(info) { + var _a; + if ((_a = info.outputs) === null || _a === void 0 ? void 0 : _a.length) { + info.outputs.length = 1; + } + super.configure(info); + } + onDrawForeground(ctx, canvas) { + var _a; + if ((_a = this.flags) === null || _a === void 0 ? void 0 : _a.collapsed) { + return; + } + if (this.properties["on_muted_inputs"] !== "MUTE" || + this.properties["on_bypassed_inputs"] !== "BYPASS" || + this.properties["on_any_active_inputs"] != "ACTIVE") { + let margin = 15; + ctx.textAlign = "left"; + let label = `*(MUTE > ${this.properties["on_muted_inputs"]}, `; + label += `BYPASS > ${this.properties["on_bypassed_inputs"]}, `; + label += `ACTIVE > ${this.properties["on_any_active_inputs"]})`; + ctx.fillStyle = LiteGraph.WIDGET_SECONDARY_TEXT_COLOR; + const oldFont = ctx.font; + ctx.font = "italic " + (LiteGraph.NODE_SUBTEXT_SIZE - 2) + "px Arial"; + ctx.fillText(fitString(ctx, label, this.size[0] - 20), 15, this.size[1] - 6); + ctx.font = oldFont; + } + } + computeSize(out) { + let size = super.computeSize(out); + if (this.properties["on_muted_inputs"] !== "MUTE" || + this.properties["on_bypassed_inputs"] !== "BYPASS" || + this.properties["on_any_active_inputs"] != "ACTIVE") { + size[1] += 17; + } + return size; + } + onConnectOutput(outputIndex, inputType, inputSlot, inputNode, inputIndex) { + var _a, _b; + let canConnect = (_a = super.onConnectOutput) === null || _a === void 0 ? void 0 : _a.call(this, outputIndex, inputType, inputSlot, inputNode, inputIndex); + let nextNode = (_b = getConnectedOutputNodesAndFilterPassThroughs(this, inputNode)[0]) !== null && _b !== void 0 ? 
_b : inputNode; + return canConnect && nextNode.type === NodeTypesString.NODE_MODE_REPEATER; + } + onConnectionsChange(type, slotIndex, isConnected, link_info, ioSlot) { + super.onConnectionsChange(type, slotIndex, isConnected, link_info, ioSlot); + setTimeout(() => { + this.stabilize(); + }, 500); + } + stabilize() { + if (!this.graph || !this.isAnyOutputConnected() || !this.isInputConnected(0)) { + return; + } + const inputNodes = getConnectedInputNodesAndFilterPassThroughs(this, this, -1, this.inputsPassThroughFollowing); + let mode = undefined; + for (const inputNode of inputNodes) { + if (mode === undefined) { + mode = inputNode.mode; + } + else if (mode === inputNode.mode && MODE_REPEATS.includes(mode)) { + continue; + } + else if (inputNode.mode === MODE_ALWAYS || mode === MODE_ALWAYS) { + mode = MODE_ALWAYS; + } + else { + mode = null; + } + } + this.dispatchModeToRepeater(mode); + setTimeout(() => { + this.stabilize(); + }, 500); + } + dispatchModeToRepeater(mode) { + var _a, _b; + if (mode != null) { + const propertyVal = (_a = this.properties) === null || _a === void 0 ? void 0 : _a[MODE_TO_PROPERTY.get(mode) || ""]; + const newMode = OPTION_TO_MODE.get(propertyVal); + mode = (newMode !== null ? newMode : mode); + if (mode !== null && mode !== MODE_NOTHING) { + if ((_b = this.outputs) === null || _b === void 0 ? void 0 : _b.length) { + const outputNodes = getConnectedOutputNodesAndFilterPassThroughs(this); + for (const outputNode of outputNodes) { + outputNode.mode = mode; + wait(16).then(() => { + outputNode.setDirtyCanvas(true, true); + }); + } + } + } + } + } + getHelp() { + return ` +

+ This node will relay its input nodes' modes (Mute, Bypass, or Active) to a connected
+ ${stripRgthree(NodeTypesString.NODE_MODE_REPEATER)} (which would then repeat that mode
+ change to all of its inputs).
+
+ • When all connected input nodes are muted, the relay will set a connected repeater to
+   mute (by default).
+ • When all connected input nodes are bypassed, the relay will set a connected repeater to
+   bypass (by default).
+ • When any connected input nodes are active, the relay will set a connected repeater to
+   active (by default).
+ • If no inputs are connected, the relay will set a connected repeater to its mode when
+   its own mode is changed. Note, if any inputs are connected, then the above
+   will occur and the Relay's mode does not matter.
+
+ Note, you can change which signals get sent on the above in the Properties.
+ For instance, you could configure an inverse relay which will send a MUTE when any of its
+ inputs are active (instead of sending an ACTIVE signal), and send an ACTIVE signal when all
+ of its inputs are muted (instead of sending a MUTE signal), etc.

+ `; + } +} +NodeModeRelay.type = NodeTypesString.NODE_MODE_RELAY; +NodeModeRelay.title = NodeTypesString.NODE_MODE_RELAY; +NodeModeRelay["@on_muted_inputs"] = { + type: "combo", + values: ["MUTE", "ACTIVE", "BYPASS", "NOTHING"], +}; +NodeModeRelay["@on_bypassed_inputs"] = { + type: "combo", + values: ["BYPASS", "ACTIVE", "MUTE", "NOTHING"], +}; +NodeModeRelay["@on_any_active_inputs"] = { + type: "combo", + values: ["BYPASS", "ACTIVE", "MUTE", "NOTHING"], +}; +app.registerExtension({ + name: "rgthree.NodeModeRepeaterHelper", + registerCustomNodes() { + addConnectionLayoutSupport(NodeModeRelay, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + LiteGraph.registerNodeType(NodeModeRelay.type, NodeModeRelay); + NodeModeRelay.category = NodeModeRelay._category; + }, +}); diff --git a/rgthree-comfy/web/comfyui/node_mode_repeater.js b/rgthree-comfy/web/comfyui/node_mode_repeater.js new file mode 100644 index 0000000000000000000000000000000000000000..674de8ffaab6f3170b7b8a7bf4e1f774c00b55f2 --- /dev/null +++ b/rgthree-comfy/web/comfyui/node_mode_repeater.js @@ -0,0 +1,152 @@ +import { app } from "../../scripts/app.js"; +import { BaseCollectorNode } from "./base_node_collector.js"; +import { NodeTypesString, stripRgthree } from "./constants.js"; +import { PassThroughFollowing, addConnectionLayoutSupport, getConnectedInputNodesAndFilterPassThroughs, getConnectedOutputNodesAndFilterPassThroughs, } from "./utils.js"; +class NodeModeRepeater extends BaseCollectorNode { + constructor(title) { + super(title); + this.inputsPassThroughFollowing = PassThroughFollowing.ALL; + this.comfyClass = NodeTypesString.NODE_MODE_REPEATER; + this.hasRelayInput = false; + this.hasTogglerOutput = false; + this.onConstructed(); + } + onConstructed() { + this.addOutput("OPT_CONNECTION", "*", { + color_on: "#Fc0", + color_off: "#a80", + }); + return super.onConstructed(); + } + configure(info) { + var _a; + if ((_a = info.outputs) === null || _a === void 0 ? void 0 : _a.length) { + info.outputs.length = 1; + } + super.configure(info); + } + onConnectOutput(outputIndex, inputType, inputSlot, inputNode, inputIndex) { + let canConnect = !this.hasRelayInput; + canConnect = + canConnect && super.onConnectOutput(outputIndex, inputType, inputSlot, inputNode, inputIndex); + let nextNode = getConnectedOutputNodesAndFilterPassThroughs(this, inputNode)[0] || inputNode; + return (canConnect && + [ + NodeTypesString.FAST_MUTER, + NodeTypesString.FAST_BYPASSER, + NodeTypesString.NODE_COLLECTOR, + NodeTypesString.FAST_ACTIONS_BUTTON, + NodeTypesString.REROUTE, + NodeTypesString.RANDOM_UNMUTER, + ].includes(nextNode.type || "")); + } + onConnectInput(inputIndex, outputType, outputSlot, outputNode, outputIndex) { + var _a; + let canConnect = (_a = super.onConnectInput) === null || _a === void 0 ? void 0 : _a.call(this, inputIndex, outputType, outputSlot, outputNode, outputIndex); + let nextNode = getConnectedOutputNodesAndFilterPassThroughs(this, outputNode)[0] || outputNode; + const isNextNodeRelay = nextNode.type === NodeTypesString.NODE_MODE_RELAY; + return canConnect && (!isNextNodeRelay || !this.hasTogglerOutput); + } + onConnectionsChange(type, slotIndex, isConnected, linkInfo, ioSlot) { + super.onConnectionsChange(type, slotIndex, isConnected, linkInfo, ioSlot); + let hasTogglerOutput = false; + let hasRelayInput = false; + const outputNodes = getConnectedOutputNodesAndFilterPassThroughs(this); + for (const outputNode of outputNodes) { + if ((outputNode === null || outputNode === void 0 ? 
void 0 : outputNode.type) === NodeTypesString.FAST_MUTER || + (outputNode === null || outputNode === void 0 ? void 0 : outputNode.type) === NodeTypesString.FAST_BYPASSER) { + hasTogglerOutput = true; + break; + } + } + const inputNodes = getConnectedInputNodesAndFilterPassThroughs(this); + for (const [index, inputNode] of inputNodes.entries()) { + if ((inputNode === null || inputNode === void 0 ? void 0 : inputNode.type) === NodeTypesString.NODE_MODE_RELAY) { + if (hasTogglerOutput) { + console.log(`Can't be connected to a Relay if also output to a toggler.`); + this.disconnectInput(index); + } + else { + hasRelayInput = true; + if (this.inputs[index]) { + this.inputs[index].color_on = "#FC0"; + this.inputs[index].color_off = "#a80"; + } + } + } + else { + inputNode.mode = this.mode; + } + } + this.hasTogglerOutput = hasTogglerOutput; + this.hasRelayInput = hasRelayInput; + if (this.hasRelayInput) { + if (this.outputs[0]) { + this.disconnectOutput(0); + this.removeOutput(0); + } + } + else if (!this.outputs[0]) { + this.addOutput("OPT_CONNECTION", "*", { + color_on: "#Fc0", + color_off: "#a80", + }); + } + } + onModeChange(from, to) { + var _a, _b; + super.onModeChange(from, to); + const linkedNodes = getConnectedInputNodesAndFilterPassThroughs(this).filter((node) => node.type !== NodeTypesString.NODE_MODE_RELAY); + if (linkedNodes.length) { + for (const node of linkedNodes) { + if (node.type !== NodeTypesString.NODE_MODE_RELAY) { + node.mode = this.mode; + } + } + } + else if ((_a = app.graph._groups) === null || _a === void 0 ? void 0 : _a.length) { + for (const group of app.graph._groups) { + group.recomputeInsideNodes(); + if ((_b = group._nodes) === null || _b === void 0 ? void 0 : _b.includes(this)) { + for (const node of group._nodes) { + node.mode = this.mode; + } + } + } + } + } + getHelp() { + return ` +

+ When this node's mode (Mute, Bypass, Active) changes, it will "repeat" that mode to all
+ connected input nodes, or, if there are no connected nodes AND it is overlapping a group,
+ "repeat" its mode to all nodes in that group.
+
+ • Optionally, connect this node's output to a ${stripRgthree(NodeTypesString.FAST_MUTER)}
+   or ${stripRgthree(NodeTypesString.FAST_BYPASSER)} for a single toggle to quickly
+   mute/bypass all its connected nodes.
+ • Optionally, connect a ${stripRgthree(NodeTypesString.NODE_MODE_RELAY)} to this node's
+   inputs to have it automatically toggle its mode. If connected, this will always take
+   precedence (and disconnect any connected fast togglers).
+ `; + } +} +NodeModeRepeater.type = NodeTypesString.NODE_MODE_REPEATER; +NodeModeRepeater.title = NodeTypesString.NODE_MODE_REPEATER; +app.registerExtension({ + name: "rgthree.NodeModeRepeater", + registerCustomNodes() { + addConnectionLayoutSupport(NodeModeRepeater, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + LiteGraph.registerNodeType(NodeModeRepeater.type, NodeModeRepeater); + NodeModeRepeater.category = NodeModeRepeater._category; + }, +}); diff --git a/rgthree-comfy/web/comfyui/power_lora_loader.js b/rgthree-comfy/web/comfyui/power_lora_loader.js new file mode 100644 index 0000000000000000000000000000000000000000..0c5da1832403d44cce4beb591c2690b804b455be --- /dev/null +++ b/rgthree-comfy/web/comfyui/power_lora_loader.js @@ -0,0 +1,562 @@ +var _a; +import { app } from "../../scripts/app.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { rgthree } from "./rgthree.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { NodeTypesString } from "./constants.js"; +import { drawInfoIcon, drawNumberWidgetPart, drawRoundedRectangle, drawTogglePart, fitString, isLowQuality, } from "./utils_canvas.js"; +import { RgthreeBaseWidget, RgthreeBetterButtonWidget, RgthreeDividerWidget, } from "./utils_widgets.js"; +import { rgthreeApi } from "../../rgthree/common/rgthree_api.js"; +import { showLoraChooser } from "./utils_menu.js"; +import { moveArrayItem, removeArrayItem } from "../../rgthree/common/shared_utils.js"; +import { RgthreeInfoDialog } from "./dialog_info.js"; +import { SERVICE as MODEL_INFO_SERVICE } from "../../rgthree/common/model_info_service.js"; +const PROP_LABEL_SHOW_STRENGTHS = "Show Strengths"; +const PROP_LABEL_SHOW_STRENGTHS_STATIC = `@${PROP_LABEL_SHOW_STRENGTHS}`; +const PROP_VALUE_SHOW_STRENGTHS_SINGLE = "Single Strength"; +const PROP_VALUE_SHOW_STRENGTHS_SEPARATE = "Separate Model & Clip"; +class RgthreePowerLoraLoader extends RgthreeBaseServerNode { + constructor(title = NODE_CLASS.title) { + super(title); + this.serialize_widgets = true; + this.logger = rgthree.newLogSession(`[Power Lora Stack]`); + this.loraWidgetsCounter = 0; + this.widgetButtonSpacer = null; + this.properties[PROP_LABEL_SHOW_STRENGTHS] = PROP_VALUE_SHOW_STRENGTHS_SINGLE; + rgthreeApi.getLoras(); + } + configure(info) { + var _b; + while ((_b = this.widgets) === null || _b === void 0 ? void 0 : _b.length) + this.removeWidget(0); + this.widgetButtonSpacer = null; + super.configure(info); + this._tempWidth = this.size[0]; + this._tempHeight = this.size[1]; + for (const widgetValue of info.widgets_values || []) { + if ((widgetValue === null || widgetValue === void 0 ? void 0 : widgetValue.lora) !== undefined) { + const widget = this.addNewLoraWidget(); + widget.value = { ...widgetValue }; + } + } + this.addNonLoraWidgets(); + this.size[0] = this._tempWidth; + this.size[1] = Math.max(this._tempHeight, this.computeSize()[1]); + } + onNodeCreated() { + var _b; + (_b = super.onNodeCreated) === null || _b === void 0 ? 
void 0 : _b.call(this); + this.addNonLoraWidgets(); + const computed = this.computeSize(); + this.size = this.size || [0, 0]; + this.size[0] = Math.max(this.size[0], computed[0]); + this.size[1] = Math.max(this.size[1], computed[1]); + this.setDirtyCanvas(true, true); + } + addNewLoraWidget(lora) { + this.loraWidgetsCounter++; + const widget = this.addCustomWidget(new PowerLoraLoaderWidget("lora_" + this.loraWidgetsCounter)); + if (lora) + widget.setLora(lora); + if (this.widgetButtonSpacer) { + moveArrayItem(this.widgets, widget, this.widgets.indexOf(this.widgetButtonSpacer)); + } + return widget; + } + addNonLoraWidgets() { + moveArrayItem(this.widgets, this.addCustomWidget(new RgthreeDividerWidget({ marginTop: 4, marginBottom: 0, thickness: 0 })), 0); + moveArrayItem(this.widgets, this.addCustomWidget(new PowerLoraLoaderHeaderWidget()), 1); + this.widgetButtonSpacer = this.addCustomWidget(new RgthreeDividerWidget({ marginTop: 4, marginBottom: 0, thickness: 0 })); + this.addCustomWidget(new RgthreeBetterButtonWidget("➕ Add Lora", (event, pos, node) => { + rgthreeApi.getLoras().then((loras) => { + showLoraChooser(event, (value) => { + var _b; + if (typeof value === "string") { + if (value.includes("Power Lora Chooser")) { + } + else if (value !== "NONE") { + this.addNewLoraWidget(value); + const computed = this.computeSize(); + const tempHeight = (_b = this._tempHeight) !== null && _b !== void 0 ? _b : 15; + this.size[1] = Math.max(tempHeight, computed[1]); + this.setDirtyCanvas(true, true); + } + } + }, null, [...loras]); + }); + return true; + })); + } + getSlotInPosition(canvasX, canvasY) { + var _b; + const slot = super.getSlotInPosition(canvasX, canvasY); + if (!slot) { + let lastWidget = null; + for (const widget of this.widgets) { + if (!widget.last_y) + return; + if (canvasY > this.pos[1] + widget.last_y) { + lastWidget = widget; + continue; + } + break; + } + if ((_b = lastWidget === null || lastWidget === void 0 ? void 0 : lastWidget.name) === null || _b === void 0 ? void 0 : _b.startsWith("lora_")) { + return { widget: lastWidget, output: { type: "LORA WIDGET" } }; + } + } + return slot; + } + getSlotMenuOptions(slot) { + var _b, _c, _d, _e, _f, _g; + if ((_c = (_b = slot === null || slot === void 0 ? void 0 : slot.widget) === null || _b === void 0 ? void 0 : _b.name) === null || _c === void 0 ? void 0 : _c.startsWith("lora_")) { + const widget = slot.widget; + const index = this.widgets.indexOf(widget); + const canMoveUp = !!((_e = (_d = this.widgets[index - 1]) === null || _d === void 0 ? void 0 : _d.name) === null || _e === void 0 ? void 0 : _e.startsWith("lora_")); + const canMoveDown = !!((_g = (_f = this.widgets[index + 1]) === null || _f === void 0 ? void 0 : _f.name) === null || _g === void 0 ? void 0 : _g.startsWith("lora_")); + const menuItems = [ + { + content: `ℹ️ Show Info`, + callback: () => { + widget.showLoraInfoDialog(); + }, + }, + null, + { + content: `${widget.value.on ? "⚫" : "🟢"} Toggle ${widget.value.on ? 
"Off" : "On"}`, + callback: () => { + widget.value.on = !widget.value.on; + }, + }, + { + content: `⬆️ Move Up`, + disabled: !canMoveUp, + callback: () => { + moveArrayItem(this.widgets, widget, index - 1); + }, + }, + { + content: `⬇️ Move Down`, + disabled: !canMoveDown, + callback: () => { + moveArrayItem(this.widgets, widget, index + 1); + }, + }, + { + content: `🗑️ Remove`, + callback: () => { + removeArrayItem(this.widgets, widget); + }, + }, + ]; + let canvas = app.canvas; + new LiteGraph.ContextMenu(menuItems, { title: "LORA WIDGET", event: rgthree.lastAdjustedMouseEvent }, canvas.getCanvasWindow()); + return null; + } + return this.defaultGetSlotMenuOptions(slot); + } + refreshComboInNode(defs) { + rgthreeApi.getLoras(true); + } + hasLoraWidgets() { + var _b; + return !!((_b = this.widgets) === null || _b === void 0 ? void 0 : _b.find((w) => { var _b; return (_b = w.name) === null || _b === void 0 ? void 0 : _b.startsWith("lora_"); })); + } + allLorasState() { + var _b, _c, _d; + let allOn = true; + let allOff = true; + for (const widget of this.widgets) { + if ((_b = widget.name) === null || _b === void 0 ? void 0 : _b.startsWith("lora_")) { + const on = (_c = widget.value) === null || _c === void 0 ? void 0 : _c.on; + allOn = allOn && on === true; + allOff = allOff && on === false; + if (!allOn && !allOff) { + return null; + } + } + } + return allOn && ((_d = this.widgets) === null || _d === void 0 ? void 0 : _d.length) ? true : false; + } + toggleAllLoras() { + var _b; + const allOn = this.allLorasState(); + const toggledTo = !allOn ? true : false; + for (const widget of this.widgets) { + if ((_b = widget.name) === null || _b === void 0 ? void 0 : _b.startsWith("lora_")) { + widget.value.on = toggledTo; + } + } + } + static setUp(comfyClass, nodeData) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, NODE_CLASS); + } + static onRegisteredForOverride(comfyClass, ctxClass) { + addConnectionLayoutSupport(NODE_CLASS, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + setTimeout(() => { + NODE_CLASS.category = comfyClass.category; + }); + } + getHelp() { + return ` +

+        <p>
+          The ${this.type.replace("(rgthree)", "")} is a powerful node that condenses hundreds of
+          pixels of functionality into a single, dynamic node that allows you to add loras, change
+          strengths, and quickly toggle them on and off, all without taking up half your screen.
+        </p>
+        <ul>
+          <li><p>
+            Add as many Loras as you would like by clicking the "+ Add Lora" button.
+            There's no real limit!
+          </p></li>
+          <li><p>
+            Right-click on a Lora widget for special options to move the lora up or down
+            (purely presentational; it has no effect on the image), toggle it on/off, or delete
+            the row altogether.
+          </p></li>
+          <li>
+            <p>
+              Properties. You can change the following properties (by right-clicking on the node
+              and selecting "Properties" or "Properties Panel" from the menu):
+            </p>
+            <ul>
+              <li><p>
+                ${PROP_LABEL_SHOW_STRENGTHS} - Change between showing a single, simple strength
+                (which will be used for both model and clip), or a more advanced view with both
+                model and clip strengths being modifiable (an example of the saved row data is
+                shown below this list).
+              </p></li>
+            </ul>
+          </li>
+        </ul>
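+        <p>
+          For illustration only: each lora row is serialized into the workflow roughly as the
+          object below (the filename is a made-up placeholder, and <code>strengthTwo</code> is
+          only written when "${PROP_VALUE_SHOW_STRENGTHS_SEPARATE}" is selected).
+        </p>
+        <pre><code>{ "on": true, "lora": "example_lora.safetensors", "strength": 0.8, "strengthTwo": 0.8 }</code></pre>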
`; + } +} +_a = PROP_LABEL_SHOW_STRENGTHS_STATIC; +RgthreePowerLoraLoader.title = NodeTypesString.POWER_LORA_LOADER; +RgthreePowerLoraLoader.type = NodeTypesString.POWER_LORA_LOADER; +RgthreePowerLoraLoader.comfyClass = NodeTypesString.POWER_LORA_LOADER; +RgthreePowerLoraLoader[_a] = { + type: "combo", + values: [PROP_VALUE_SHOW_STRENGTHS_SINGLE, PROP_VALUE_SHOW_STRENGTHS_SEPARATE], +}; +class PowerLoraLoaderHeaderWidget extends RgthreeBaseWidget { + constructor(name = "PowerLoraLoaderHeaderWidget") { + super(name); + this.showModelAndClip = null; + this.value = { type: "PowerLoraLoaderHeaderWidget" }; + this.hitAreas = { + toggle: { bounds: [0, 0], onDown: this.onToggleDown }, + }; + } + draw(ctx, node, w, posY, height) { + if (!node.hasLoraWidgets()) { + return; + } + this.showModelAndClip = + node.properties[PROP_LABEL_SHOW_STRENGTHS] === PROP_VALUE_SHOW_STRENGTHS_SEPARATE; + const margin = 10; + const innerMargin = margin * 0.33; + const lowQuality = isLowQuality(); + const allLoraState = node.allLorasState(); + posY += 2; + const midY = posY + height * 0.5; + let posX = 10; + ctx.save(); + this.hitAreas.toggle.bounds = drawTogglePart(ctx, { posX, posY, height, value: allLoraState }); + if (!lowQuality) { + posX += this.hitAreas.toggle.bounds[1] + innerMargin; + ctx.globalAlpha = app.canvas.editor_alpha * 0.55; + ctx.fillStyle = LiteGraph.WIDGET_TEXT_COLOR; + ctx.textAlign = "left"; + ctx.textBaseline = "middle"; + ctx.fillText("Toggle All", posX, midY); + let rposX = node.size[0] - margin - innerMargin - innerMargin; + ctx.textAlign = "center"; + ctx.fillText(this.showModelAndClip ? "Clip" : "Strength", rposX - drawNumberWidgetPart.WIDTH_TOTAL / 2, midY); + if (this.showModelAndClip) { + rposX = rposX - drawNumberWidgetPart.WIDTH_TOTAL - innerMargin * 2; + ctx.fillText("Model", rposX - drawNumberWidgetPart.WIDTH_TOTAL / 2, midY); + } + } + ctx.restore(); + } + onToggleDown(event, pos, node) { + node.toggleAllLoras(); + this.cancelMouseDown(); + return true; + } +} +const DEFAULT_LORA_WIDGET_DATA = { + on: true, + lora: null, + strength: 1, + strengthTwo: null, +}; +class PowerLoraLoaderWidget extends RgthreeBaseWidget { + constructor(name) { + super(name); + this.haveMouseMovedStrength = false; + this.loraInfoPromise = null; + this.loraInfo = null; + this.showModelAndClip = null; + this.hitAreas = { + toggle: { bounds: [0, 0], onDown: this.onToggleDown }, + lora: { bounds: [0, 0], onDown: this.onLoraDown }, + strengthDec: { bounds: [0, 0], onDown: this.onStrengthDecDown }, + strengthVal: { bounds: [0, 0], onUp: this.onStrengthValUp }, + strengthInc: { bounds: [0, 0], onDown: this.onStrengthIncDown }, + strengthAny: { bounds: [0, 0], onMove: this.onStrengthAnyMove }, + strengthTwoDec: { bounds: [0, 0], onDown: this.onStrengthTwoDecDown }, + strengthTwoVal: { bounds: [0, 0], onUp: this.onStrengthTwoValUp }, + strengthTwoInc: { bounds: [0, 0], onDown: this.onStrengthTwoIncDown }, + strengthTwoAny: { bounds: [0, 0], onMove: this.onStrengthTwoAnyMove }, + }; + this._value = { + on: true, + lora: null, + strength: 1, + strengthTwo: null, + }; + } + set value(v) { + this._value = v; + if (typeof this._value !== "object") { + this._value = { ...DEFAULT_LORA_WIDGET_DATA }; + if (this.showModelAndClip) { + this._value.strengthTwo = this._value.strength; + } + } + this.getLoraInfo(); + } + get value() { + return this._value; + } + setLora(lora) { + this._value.lora = lora; + this.getLoraInfo(); + } + draw(ctx, node, w, posY, height) { + var _b, _c, _d, _e, _f, _g, _h, _j, _k, _l, _m, _o, _p; + 
let currentShowModelAndClip = node.properties[PROP_LABEL_SHOW_STRENGTHS] === PROP_VALUE_SHOW_STRENGTHS_SEPARATE; + if (this.showModelAndClip !== currentShowModelAndClip) { + let oldShowModelAndClip = this.showModelAndClip; + this.showModelAndClip = currentShowModelAndClip; + if (this.showModelAndClip) { + if (oldShowModelAndClip != null) { + this.value.strengthTwo = (_b = this.value.strength) !== null && _b !== void 0 ? _b : 1; + } + } + else { + this.value.strengthTwo = null; + this.hitAreas.strengthTwoDec.bounds = [0, -1]; + this.hitAreas.strengthTwoVal.bounds = [0, -1]; + this.hitAreas.strengthTwoInc.bounds = [0, -1]; + this.hitAreas.strengthTwoAny.bounds = [0, -1]; + } + } + ctx.save(); + const margin = 10; + const innerMargin = margin * 0.33; + const lowQuality = isLowQuality(); + const midY = posY + height * 0.5; + let posX = margin; + drawRoundedRectangle(ctx, { posX, posY, height, width: node.size[0] - margin * 2 }); + this.hitAreas.toggle.bounds = drawTogglePart(ctx, { posX, posY, height, value: this.value.on }); + posX += this.hitAreas.toggle.bounds[1] + innerMargin; + if (lowQuality) { + ctx.restore(); + return; + } + if (!this.value.on) { + ctx.globalAlpha = app.canvas.editor_alpha * 0.4; + } + ctx.fillStyle = LiteGraph.WIDGET_TEXT_COLOR; + let rposX = node.size[0] - margin - innerMargin - innerMargin; + const strengthValue = this.showModelAndClip + ? (_c = this.value.strengthTwo) !== null && _c !== void 0 ? _c : 1 + : (_d = this.value.strength) !== null && _d !== void 0 ? _d : 1; + let textColor = undefined; + if (((_e = this.loraInfo) === null || _e === void 0 ? void 0 : _e.strengthMax) != null && strengthValue > ((_f = this.loraInfo) === null || _f === void 0 ? void 0 : _f.strengthMax)) { + textColor = "#c66"; + } + else if (((_g = this.loraInfo) === null || _g === void 0 ? void 0 : _g.strengthMin) != null && strengthValue < ((_h = this.loraInfo) === null || _h === void 0 ? void 0 : _h.strengthMin)) { + textColor = "#c66"; + } + const [leftArrow, text, rightArrow] = drawNumberWidgetPart(ctx, { + posX: node.size[0] - margin - innerMargin - innerMargin, + posY, + height, + value: strengthValue, + direction: -1, + textColor, + }); + this.hitAreas.strengthDec.bounds = leftArrow; + this.hitAreas.strengthVal.bounds = text; + this.hitAreas.strengthInc.bounds = rightArrow; + this.hitAreas.strengthAny.bounds = [leftArrow[0], rightArrow[0] + rightArrow[1] - leftArrow[0]]; + rposX = leftArrow[0] - innerMargin; + if (this.showModelAndClip) { + rposX -= innerMargin; + this.hitAreas.strengthTwoDec.bounds = this.hitAreas.strengthDec.bounds; + this.hitAreas.strengthTwoVal.bounds = this.hitAreas.strengthVal.bounds; + this.hitAreas.strengthTwoInc.bounds = this.hitAreas.strengthInc.bounds; + this.hitAreas.strengthTwoAny.bounds = this.hitAreas.strengthAny.bounds; + let textColor = undefined; + if (((_j = this.loraInfo) === null || _j === void 0 ? void 0 : _j.strengthMax) != null && this.value.strength > ((_k = this.loraInfo) === null || _k === void 0 ? void 0 : _k.strengthMax)) { + textColor = "#c66"; + } + else if (((_l = this.loraInfo) === null || _l === void 0 ? void 0 : _l.strengthMin) != null && + this.value.strength < ((_m = this.loraInfo) === null || _m === void 0 ? void 0 : _m.strengthMin)) { + textColor = "#c66"; + } + const [leftArrow, text, rightArrow] = drawNumberWidgetPart(ctx, { + posX: rposX, + posY, + height, + value: (_o = this.value.strength) !== null && _o !== void 0 ? 
_o : 1, + direction: -1, + textColor, + }); + this.hitAreas.strengthDec.bounds = leftArrow; + this.hitAreas.strengthVal.bounds = text; + this.hitAreas.strengthInc.bounds = rightArrow; + this.hitAreas.strengthAny.bounds = [ + leftArrow[0], + rightArrow[0] + rightArrow[1] - leftArrow[0], + ]; + rposX = leftArrow[0] - innerMargin; + } + const infoIconSize = height * 0.66; + const infoWidth = infoIconSize + innerMargin + innerMargin; + if (this.hitAreas["info"]) { + rposX -= innerMargin; + drawInfoIcon(ctx, rposX - infoIconSize, posY + (height - infoIconSize) / 2, infoIconSize); + this.hitAreas.info.bounds = [rposX - infoIconSize, infoWidth]; + rposX = rposX - infoIconSize - innerMargin; + } + const loraWidth = rposX - posX; + ctx.textAlign = "left"; + ctx.textBaseline = "middle"; + const loraLabel = String(((_p = this.value) === null || _p === void 0 ? void 0 : _p.lora) || "None"); + ctx.fillText(fitString(ctx, loraLabel, loraWidth), posX, midY); + this.hitAreas.lora.bounds = [posX, loraWidth]; + posX += loraWidth + innerMargin; + ctx.globalAlpha = app.canvas.editor_alpha; + ctx.restore(); + } + serializeValue(serializedNode, widgetIndex) { + var _b; + const v = { ...this.value }; + if (!this.showModelAndClip) { + delete v.strengthTwo; + } + else { + this.value.strengthTwo = (_b = this.value.strengthTwo) !== null && _b !== void 0 ? _b : 1; + v.strengthTwo = this.value.strengthTwo; + } + return v; + } + onToggleDown(event, pos, node) { + this.value.on = !this.value.on; + this.cancelMouseDown(); + return true; + } + onInfoDown(event, pos, node) { + this.showLoraInfoDialog(); + } + onLoraDown(event, pos, node) { + showLoraChooser(event, (value) => { + if (typeof value === "string") { + this.value.lora = value; + this.loraInfo = null; + this.getLoraInfo(); + } + node.setDirtyCanvas(true, true); + }); + this.cancelMouseDown(); + } + onStrengthDecDown(event, pos, node) { + this.stepStrength(-1, false); + } + onStrengthIncDown(event, pos, node) { + this.stepStrength(1, false); + } + onStrengthTwoDecDown(event, pos, node) { + this.stepStrength(-1, true); + } + onStrengthTwoIncDown(event, pos, node) { + this.stepStrength(1, true); + } + onStrengthAnyMove(event, pos, node) { + this.doOnStrengthAnyMove(event, false); + } + onStrengthTwoAnyMove(event, pos, node) { + this.doOnStrengthAnyMove(event, true); + } + doOnStrengthAnyMove(event, isTwo = false) { + var _b; + if (event.deltaX) { + let prop = isTwo ? "strengthTwo" : "strength"; + this.haveMouseMovedStrength = true; + this.value[prop] = ((_b = this.value[prop]) !== null && _b !== void 0 ? _b : 1) + event.deltaX * 0.05; + } + } + onStrengthValUp(event, pos, node) { + this.doOnStrengthValUp(event, false); + } + onStrengthTwoValUp(event, pos, node) { + this.doOnStrengthValUp(event, true); + } + doOnStrengthValUp(event, isTwo = false) { + if (this.haveMouseMovedStrength) + return; + let prop = isTwo ? "strengthTwo" : "strength"; + const canvas = app.canvas; + canvas.prompt("Value", this.value[prop], (v) => (this.value[prop] = Number(v)), event); + } + onMouseUp(event, pos, node) { + super.onMouseUp(event, pos, node); + this.haveMouseMovedStrength = false; + } + showLoraInfoDialog() { + if (!this.value.lora || this.value.lora === "None") { + return; + } + const infoDialog = new RgthreeInfoDialog(this.value.lora).show(); + infoDialog.addEventListener("close", ((e) => { + if (e.detail.dirty) { + this.getLoraInfo(true); + } + })); + } + stepStrength(direction, isTwo = false) { + var _b; + let step = 0.05; + let prop = isTwo ? 
"strengthTwo" : "strength"; + let strength = ((_b = this.value[prop]) !== null && _b !== void 0 ? _b : 1) + step * direction; + this.value[prop] = Math.round(strength * 100) / 100; + } + getLoraInfo(force = false) { + if (!this.loraInfoPromise || force == true) { + let promise; + if (this.value.lora && this.value.lora != "None") { + promise = MODEL_INFO_SERVICE.getLora(this.value.lora, force, true); + } + else { + promise = Promise.resolve(null); + } + this.loraInfoPromise = promise.then((v) => (this.loraInfo = v)); + } + return this.loraInfoPromise; + } +} +const NODE_CLASS = RgthreePowerLoraLoader; +app.registerExtension({ + name: "rgthree.PowerLoraLoader", + async beforeRegisterNodeDef(nodeType, nodeData) { + if (nodeData.name === NODE_CLASS.type) { + NODE_CLASS.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/power_prompt.js b/rgthree-comfy/web/comfyui/power_prompt.js new file mode 100644 index 0000000000000000000000000000000000000000..c73cdfc59dda0739002e2ae00652c9c5bc3ec4e2 --- /dev/null +++ b/rgthree-comfy/web/comfyui/power_prompt.js @@ -0,0 +1,42 @@ +import { app } from "../../scripts/app.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { PowerPrompt } from "./base_power_prompt.js"; +import { NodeTypesString } from "./constants.js"; +let nodeData = null; +app.registerExtension({ + name: "rgthree.PowerPrompt", + async beforeRegisterNodeDef(nodeType, passedNodeData, _app) { + if (passedNodeData.name.includes("Power Prompt") && passedNodeData.name.includes("rgthree")) { + nodeData = passedNodeData; + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + onNodeCreated ? onNodeCreated.apply(this, []) : undefined; + this.powerPrompt = new PowerPrompt(this, passedNodeData); + }; + addConnectionLayoutSupport(nodeType, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + } + }, + async loadedGraphNode(node) { + if (node.type === NodeTypesString.POWER_PROMPT) { + setTimeout(() => { + if (node.outputs[0].type === "STRING") { + if (node.outputs[0].links) { + node.outputs[3].links = node.outputs[3].links || []; + for (const link of node.outputs[0].links) { + node.outputs[3].links.push(link); + app.graph.links[link].origin_slot = 3; + } + node.outputs[0].links = null; + } + node.outputs[0].type = nodeData.output[0]; + node.outputs[0].name = nodeData.output_name[0] || node.outputs[0].type; + node.outputs[0].color_on = undefined; + node.outputs[0].color_off = undefined; + } + }, 50); + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/random_unmuter.js b/rgthree-comfy/web/comfyui/random_unmuter.js new file mode 100644 index 0000000000000000000000000000000000000000..171aee4805da9b429c10ccd0452b5a108325d855 --- /dev/null +++ b/rgthree-comfy/web/comfyui/random_unmuter.js @@ -0,0 +1,101 @@ +import { app } from "../../scripts/app.js"; +import { BaseAnyInputConnectedNode } from "./base_any_input_connected_node.js"; +import { NodeTypesString } from "./constants.js"; +import { rgthree } from "./rgthree.js"; +import { getConnectedInputNodesAndFilterPassThroughs } from "./utils.js"; +const MODE_MUTE = 2; +const MODE_ALWAYS = 0; +class RandomUnmuterNode extends BaseAnyInputConnectedNode { + constructor(title = RandomUnmuterNode.title) { + super(title); + this.comfyClass = NodeTypesString.RANDOM_UNMUTER; + this.modeOn = MODE_ALWAYS; + this.modeOff = MODE_MUTE; + this.tempEnabledNode = null; + this.processingQueue = false; + this.onQueueBound = this.onQueue.bind(this); + this.onQueueEndBound = 
this.onQueueEnd.bind(this); + this.onGraphtoPromptBound = this.onGraphtoPrompt.bind(this); + this.onGraphtoPromptEndBound = this.onGraphtoPromptEnd.bind(this); + rgthree.addEventListener("queue", this.onQueueBound); + rgthree.addEventListener("queue-end", this.onQueueEndBound); + rgthree.addEventListener("graph-to-prompt", this.onGraphtoPromptBound); + rgthree.addEventListener("graph-to-prompt-end", this.onGraphtoPromptEndBound); + this.onConstructed(); + } + onRemoved() { + rgthree.removeEventListener("queue", this.onQueueBound); + rgthree.removeEventListener("queue-end", this.onQueueEndBound); + rgthree.removeEventListener("graph-to-prompt", this.onGraphtoPromptBound); + rgthree.removeEventListener("graph-to-prompt-end", this.onGraphtoPromptEndBound); + } + onQueue(event) { + this.processingQueue = true; + } + onQueueEnd(event) { + this.processingQueue = false; + } + onGraphtoPrompt(event) { + if (!this.processingQueue) { + return; + } + this.tempEnabledNode = null; + const linkedNodes = getConnectedInputNodesAndFilterPassThroughs(this); + let allMuted = true; + if (linkedNodes.length) { + for (const node of linkedNodes) { + if (node.mode !== this.modeOff) { + allMuted = false; + break; + } + } + if (allMuted) { + this.tempEnabledNode = linkedNodes[Math.floor(Math.random() * linkedNodes.length)] || null; + if (this.tempEnabledNode) { + this.tempEnabledNode.mode = this.modeOn; + } + } + } + } + onGraphtoPromptEnd(event) { + if (this.tempEnabledNode) { + this.tempEnabledNode.mode = this.modeOff; + this.tempEnabledNode = null; + } + } + handleLinkedNodesStabilization(linkedNodes) { + } + getHelp() { + return ` +

+        <p>
+          Use this node to unmute one of its inputs randomly when the graph is queued (and then
+          immediately mute it back).
+        </p>
+        <ul>
+          <li><p>
+            NOTE: All input nodes MUST be muted to start; if not, this node will not randomly
+            unmute another. (This is powerful, as the generated image can be dragged in and the
+            chosen input will already be unmuted and work without any further action.)
+          </p></li>
+          <li><p>
+            TIP: Connect a Repeater's output to this node's input and place that Repeater on a
+            group without any other inputs, and it will mute/unmute the entire group.
+          </p></li>
+        </ul>
+ `; + } +} +RandomUnmuterNode.exposedActions = ["Mute all", "Enable all"]; +RandomUnmuterNode.type = NodeTypesString.RANDOM_UNMUTER; +RandomUnmuterNode.title = RandomUnmuterNode.type; +app.registerExtension({ + name: "rgthree.RandomUnmuter", + registerCustomNodes() { + RandomUnmuterNode.setUp(); + }, + loadedGraphNode(node) { + if (node.type == RandomUnmuterNode.title) { + node._tempWidth = node.size[0]; + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/reroute.js b/rgthree-comfy/web/comfyui/reroute.js new file mode 100644 index 0000000000000000000000000000000000000000..316b64f67e9efaad746c04bce889fa96083ab8b0 --- /dev/null +++ b/rgthree-comfy/web/comfyui/reroute.js @@ -0,0 +1,979 @@ +var _a, _b; +import { app } from "../../scripts/app.js"; +import { getWidgetConfig, mergeIfValid, setWidgetConfig, } from "../../extensions/core/widgetInputs.js"; +import { rgthreeConfig } from "../../rgthree/config.js"; +import { rgthree } from "./rgthree.js"; +import { IoDirection, LAYOUT_CLOCKWISE, LAYOUT_LABEL_OPPOSITES, LAYOUT_LABEL_TO_DATA, addConnectionLayoutSupport, addMenuItem, getSlotLinks, isValidConnection, setConnectionsLayout, waitForCanvas, } from "./utils.js"; +import { SERVICE as KEY_EVENT_SERVICE } from "./services/key_events_services.js"; +import { wait } from "../../rgthree/common/shared_utils.js"; +import { RgthreeBaseVirtualNode } from "./base_node.js"; +import { NodeTypesString } from "./constants.js"; +const CONFIG_REROUTE = ((_a = rgthreeConfig === null || rgthreeConfig === void 0 ? void 0 : rgthreeConfig["nodes"]) === null || _a === void 0 ? void 0 : _a["reroute"]) || {}; +const CONFIG_FAST_REROUTE = CONFIG_REROUTE["fast_reroute"]; +const CONFIG_FAST_REROUTE_ENABLED = (_b = CONFIG_FAST_REROUTE["enabled"]) !== null && _b !== void 0 ? _b : false; +const CONFIG_KEY_CREATE_WHILE_LINKING = CONFIG_FAST_REROUTE["key_create_while_dragging_link"]; +const CONFIG_KEY_ROTATE = CONFIG_FAST_REROUTE["key_rotate"]; +const CONFIG_KEY_RESIZE = CONFIG_FAST_REROUTE["key_resize"]; +const CONFIG_KEY_MOVE = CONFIG_FAST_REROUTE["key_move"]; +const CONFIG_KEY_CXN_INPUT = CONFIG_FAST_REROUTE["key_connections_input"]; +const CONFIG_KEY_CXN_OUTPUT = CONFIG_FAST_REROUTE["key_connections_output"]; +let configWidth = Math.max(Math.round((Number(CONFIG_REROUTE["default_width"]) || 40) / 10) * 10, 10); +let configHeight = Math.max(Math.round((Number(CONFIG_REROUTE["default_height"]) || 30) / 10) * 10, 10); +while (configWidth * configHeight < 400) { + configWidth += 10; + configHeight += 10; +} +const configDefaultSize = [configWidth, configHeight]; +const configResizable = !!CONFIG_REROUTE["default_resizable"]; +let configLayout = CONFIG_REROUTE["default_layout"]; +if (!Array.isArray(configLayout)) { + configLayout = ["Left", "Right"]; +} +if (!LAYOUT_LABEL_TO_DATA[configLayout[0]]) { + configLayout[0] = "Left"; +} +if (!LAYOUT_LABEL_TO_DATA[configLayout[1]] || configLayout[0] == configLayout[1]) { + configLayout[1] = LAYOUT_LABEL_OPPOSITES[configLayout[0]]; +} +class RerouteService { + constructor() { + this.isFastLinking = false; + this.handledNewRerouteKeypress = false; + this.connectingData = null; + this.fastReroutesHistory = []; + this.handleLinkingKeydownBound = this.handleLinkingKeydown.bind(this); + this.handleLinkingKeyupBound = this.handleLinkingKeyup.bind(this); + if (CONFIG_FAST_REROUTE_ENABLED && (CONFIG_KEY_CREATE_WHILE_LINKING === null || CONFIG_KEY_CREATE_WHILE_LINKING === void 0 ? 
void 0 : CONFIG_KEY_CREATE_WHILE_LINKING.trim())) { + this.onCanvasSetUpListenerForLinking(); + } + } + async onCanvasSetUpListenerForLinking() { + const canvas = await waitForCanvas(); + const canvasProperty = true ? 'connecting_links' : 'connecting_node'; + canvas[`_${canvasProperty}`]; + const thisService = this; + Object.defineProperty(canvas, canvasProperty, { + get: function () { + return this[`_${canvasProperty}`]; + }, + set: function (value) { + var _a; + const isValNull = !value || !(value === null || value === void 0 ? void 0 : value.length); + const isPropNull = !this[`_${canvasProperty}`] || !((_a = this[`_${canvasProperty}`]) === null || _a === void 0 ? void 0 : _a.length); + const isStartingLinking = !isValNull && isPropNull; + const isStoppingLinking = !isPropNull && isValNull; + this[`_${canvasProperty}`] = value; + if (isStartingLinking) { + thisService.startingLinking(); + } + if (isStoppingLinking) { + thisService.stoppingLinking(); + thisService.connectingData = null; + } + }, + }); + } + startingLinking() { + this.isFastLinking = true; + KEY_EVENT_SERVICE.addEventListener("keydown", this.handleLinkingKeydownBound); + KEY_EVENT_SERVICE.addEventListener("keyup", this.handleLinkingKeyupBound); + } + stoppingLinking() { + this.isFastLinking = false; + this.fastReroutesHistory = []; + KEY_EVENT_SERVICE.removeEventListener("keydown", this.handleLinkingKeydownBound); + KEY_EVENT_SERVICE.removeEventListener("keyup", this.handleLinkingKeyupBound); + } + handleLinkingKeydown(event) { + if (!this.handledNewRerouteKeypress && + KEY_EVENT_SERVICE.areOnlyKeysDown(CONFIG_KEY_CREATE_WHILE_LINKING)) { + this.handledNewRerouteKeypress = true; + this.insertNewRerouteWhileLinking(); + } + } + handleLinkingKeyup(event) { + if (this.handledNewRerouteKeypress && + !KEY_EVENT_SERVICE.areOnlyKeysDown(CONFIG_KEY_CREATE_WHILE_LINKING)) { + this.handledNewRerouteKeypress = false; + } + } + getConnectingData() { + var _a, _b; + const oldCanvas = app.canvas; + if (oldCanvas.connecting_node && oldCanvas.connecting_slot != null && ((_a = oldCanvas.connecting_pos) === null || _a === void 0 ? void 0 : _a.length)) { + return { + node: oldCanvas.connecting_node, + input: oldCanvas.connecting_input, + output: oldCanvas.connecting_output, + slot: oldCanvas.connecting_slot, + pos: [...oldCanvas.connecting_pos], + }; + } + const canvas = app.canvas; + if ((_b = canvas.connecting_links) === null || _b === void 0 ? void 0 : _b.length) { + const link = canvas.connecting_links[0]; + return { + node: link.node, + input: link.input, + output: link.output, + slot: link.slot, + pos: [...link.pos], + }; + } + throw new Error("Error, handling linking keydown, but there's no link."); + } + setCanvasConnectingData(ctx) { + var _a, _b; + const oldCanvas = app.canvas; + if (oldCanvas.connecting_node && oldCanvas.connecting_slot != null && ((_a = oldCanvas.connecting_pos) === null || _a === void 0 ? void 0 : _a.length)) { + oldCanvas.connecting_node = ctx.node; + oldCanvas.connecting_input = ctx.input; + oldCanvas.connecting_output = ctx.output; + oldCanvas.connecting_slot = ctx.slot; + oldCanvas.connecting_pos = ctx.pos; + } + const canvas = app.canvas; + if ((_b = canvas.connecting_links) === null || _b === void 0 ? 
void 0 : _b.length) { + const link = canvas.connecting_links[0]; + link.node = ctx.node; + link.input = ctx.input; + link.output = ctx.output; + link.slot = ctx.slot; + link.pos = ctx.pos; + } + } + insertNewRerouteWhileLinking() { + var _a; + const canvas = app.canvas; + this.connectingData = this.getConnectingData(); + if (!this.connectingData) { + throw new Error("Error, handling linking keydown, but there's no link."); + } + const data = this.connectingData; + const node = LiteGraph.createNode("Reroute (rgthree)"); + const entry = { + node, + previous: { ...this.connectingData }, + current: undefined, + }; + this.fastReroutesHistory.push(entry); + let connectingDir = (_a = (data.input || data.output)) === null || _a === void 0 ? void 0 : _a.dir; + if (!connectingDir) { + connectingDir = data.input ? LiteGraph.LEFT : LiteGraph.RIGHT; + } + let newPos = canvas.convertEventToCanvasOffset({ + clientX: Math.round(canvas.last_mouse_position[0] / 10) * 10, + clientY: Math.round(canvas.last_mouse_position[1] / 10) * 10, + }); + entry.node.pos = newPos; + canvas.graph.add(entry.node); + canvas.selectNode(entry.node); + const distX = entry.node.pos[0] - data.pos[0]; + const distY = entry.node.pos[1] - data.pos[1]; + const layout = ["Left", "Right"]; + if (distX > 0 && Math.abs(distX) > Math.abs(distY)) { + layout[0] = data.output ? "Left" : "Right"; + layout[1] = LAYOUT_LABEL_OPPOSITES[layout[0]]; + node.pos[0] -= node.size[0] + 10; + node.pos[1] -= Math.round(node.size[1] / 2 / 10) * 10; + } + else if (distX < 0 && Math.abs(distX) > Math.abs(distY)) { + layout[0] = data.output ? "Right" : "Left"; + layout[1] = LAYOUT_LABEL_OPPOSITES[layout[0]]; + node.pos[1] -= Math.round(node.size[1] / 2 / 10) * 10; + } + else if (distY < 0 && Math.abs(distY) > Math.abs(distX)) { + layout[0] = data.output ? "Bottom" : "Top"; + layout[1] = LAYOUT_LABEL_OPPOSITES[layout[0]]; + node.pos[0] -= Math.round(node.size[0] / 2 / 10) * 10; + } + else if (distY > 0 && Math.abs(distY) > Math.abs(distX)) { + layout[0] = data.output ? "Top" : "Bottom"; + layout[1] = LAYOUT_LABEL_OPPOSITES[layout[0]]; + node.pos[0] -= Math.round(node.size[0] / 2 / 10) * 10; + node.pos[1] -= node.size[1] + 10; + } + setConnectionsLayout(entry.node, layout); + if (data.output) { + data.node.connect(data.slot, entry.node, 0); + data.node = entry.node; + data.output = entry.node.outputs[0]; + data.slot = 0; + data.pos = entry.node.getConnectionPos(false, 0); + } + else { + entry.node.connect(0, data.node, data.slot); + data.node = entry.node; + data.input = entry.node.inputs[0]; + data.slot = 0; + data.pos = entry.node.getConnectionPos(true, 0); + } + this.setCanvasConnectingData(data); + entry.current = { ...this.connectingData }; + app.graph.setDirtyCanvas(true, true); + } + handleMoveOrResizeNodeMaybeWhileDragging(node) { + const data = this.connectingData; + if (this.isFastLinking && node === (data === null || data === void 0 ? void 0 : data.node)) { + const entry = this.fastReroutesHistory[this.fastReroutesHistory.length - 1]; + if (entry) { + data.pos = entry.node.getConnectionPos(!!data.input, 0); + this.setCanvasConnectingData(data); + } + } + } + handleRemovedNodeMaybeWhileDragging(node) { + const currentEntry = this.fastReroutesHistory[this.fastReroutesHistory.length - 1]; + if ((currentEntry === null || currentEntry === void 0 ? 
void 0 : currentEntry.node) === node) { + this.setCanvasConnectingData(currentEntry.previous); + this.fastReroutesHistory.splice(this.fastReroutesHistory.length - 1, 1); + if (currentEntry.previous.node) { + app.canvas.selectNode(currentEntry.previous.node); + } + } + } +} +const SERVICE = new RerouteService(); +class RerouteNode extends RgthreeBaseVirtualNode { + constructor(title = RerouteNode.title) { + super(title); + this.comfyClass = NodeTypesString.REROUTE; + this.isVirtualNode = true; + this.hideSlotLabels = true; + this.schedulePromise = null; + this.defaultConnectionsLayout = Array.from(configLayout); + this.shortcuts = { + rotate: { keys: CONFIG_KEY_ROTATE, state: false }, + connection_input: { keys: CONFIG_KEY_CXN_INPUT, state: false }, + connection_output: { keys: CONFIG_KEY_CXN_OUTPUT, state: false }, + resize: { + keys: CONFIG_KEY_RESIZE, + state: false, + initialMousePos: [-1, -1], + initialNodeSize: [-1, -1], + initialNodePos: [-1, -1], + resizeOnSide: [-1, -1], + }, + move: { + keys: CONFIG_KEY_MOVE, + state: false, + initialMousePos: [-1, -1], + initialNodePos: [-1, -1], + }, + }; + this.onConstructed(); + } + onConstructed() { + var _a; + this.setResizable((_a = this.properties["resizable"]) !== null && _a !== void 0 ? _a : configResizable); + this.size = RerouteNode.size; + this.addInput("", "*"); + this.addOutput("", "*"); + setTimeout(() => this.applyNodeSize(), 20); + return super.onConstructed(); + } + configure(info) { + var _a, _b, _c; + if ((_a = info.outputs) === null || _a === void 0 ? void 0 : _a.length) { + info.outputs.length = 1; + } + if ((_b = info.inputs) === null || _b === void 0 ? void 0 : _b.length) { + info.inputs.length = 1; + } + super.configure(info); + this.configuring = true; + this.setResizable((_c = this.properties["resizable"]) !== null && _c !== void 0 ? _c : configResizable); + this.applyNodeSize(); + this.configuring = false; + } + setResizable(resizable) { + this.properties["resizable"] = !!resizable; + this.resizable = this.properties["resizable"]; + } + clone() { + const cloned = super.clone(); + cloned.inputs[0].type = "*"; + cloned.outputs[0].type = "*"; + return cloned; + } + onConnectionsChange(type, _slotIndex, connected, _link_info, _ioSlot) { + if (connected && type === LiteGraph.OUTPUT) { + const types = new Set(this.outputs[0].links.map((l) => app.graph.links[l].type).filter((t) => t !== "*")); + if (types.size > 1) { + const linksToDisconnect = []; + for (let i = 0; i < this.outputs[0].links.length - 1; i++) { + const linkId = this.outputs[0].links[i]; + const link = app.graph.links[linkId]; + linksToDisconnect.push(link); + } + for (const link of linksToDisconnect) { + const node = app.graph.getNodeById(link.target_id); + node.disconnectInput(link.target_slot); + } + } + } + this.scheduleStabilize(); + } + onDrawForeground(ctx, canvas) { + var _a, _b, _c, _d; + if ((_a = this.properties) === null || _a === void 0 ? void 0 : _a["showLabel"]) { + const low_quality = ((_b = canvas === null || canvas === void 0 ? void 0 : canvas.ds) === null || _b === void 0 ? void 0 : _b.scale) && canvas.ds.scale < 0.6; + if (low_quality || this.size[0] <= 10) { + return; + } + const fontSize = Math.min(14, (this.size[1] * 0.65) | 0); + ctx.save(); + ctx.fillStyle = "#888"; + ctx.font = `${fontSize}px Arial`; + ctx.textAlign = "center"; + ctx.textBaseline = "middle"; + ctx.fillText(String(this.title && this.title !== RerouteNode.title + ? this.title + : ((_d = (_c = this.outputs) === null || _c === void 0 ? 
void 0 : _c[0]) === null || _d === void 0 ? void 0 : _d.type) || ""), this.size[0] / 2, this.size[1] / 2, this.size[0] - 30); + ctx.restore(); + } + } + findInputSlot(name) { + return 0; + } + findOutputSlot(name) { + return 0; + } + disconnectOutput(slot, targetNode) { + return super.disconnectOutput(slot, targetNode); + } + disconnectInput(slot) { + var _a; + if (rgthree.replacingReroute != null && ((_a = this.inputs[0]) === null || _a === void 0 ? void 0 : _a.link)) { + const graph = app.graph; + const link = graph.links[this.inputs[0].link]; + const node = graph.getNodeById(link === null || link === void 0 ? void 0 : link.origin_id); + if (rgthree.replacingReroute !== (node === null || node === void 0 ? void 0 : node.id)) { + return false; + } + } + return super.disconnectInput(slot); + } + scheduleStabilize(ms = 64) { + if (!this.schedulePromise) { + this.schedulePromise = new Promise((resolve) => { + setTimeout(() => { + this.schedulePromise = null; + this.stabilize(); + resolve(); + }, ms); + }); + } + return this.schedulePromise; + } + stabilize() { + var _a, _b, _c, _d, _e, _f, _g, _h, _j, _k, _l; + if (this.configuring) { + return; + } + let currentNode = this; + let updateNodes = []; + let input = null; + let inputType = null; + let inputNode = null; + let inputNodeOutputSlot = null; + while (currentNode) { + updateNodes.unshift(currentNode); + const linkId = currentNode.inputs[0].link; + if (linkId !== null) { + const link = app.graph.links[linkId]; + const node = app.graph.getNodeById(link.origin_id); + if (!node) { + app.graph.removeLink(linkId); + currentNode = null; + break; + } + const type = node.constructor.type; + if (type === null || type === void 0 ? void 0 : type.includes("Reroute")) { + if (node === this) { + currentNode.disconnectInput(link.target_slot); + currentNode = null; + } + else { + currentNode = node; + } + } + else { + inputNode = node; + inputNodeOutputSlot = link.origin_slot; + input = (_a = node.outputs[inputNodeOutputSlot]) !== null && _a !== void 0 ? _a : null; + inputType = (_b = input === null || input === void 0 ? void 0 : input.type) !== null && _b !== void 0 ? _b : null; + break; + } + } + else { + currentNode = null; + break; + } + } + const nodes = [this]; + let outputNode = null; + let outputType = null; + let outputWidgetConfig = null; + let outputWidget = null; + while (nodes.length) { + currentNode = nodes.pop(); + const outputs = (currentNode.outputs ? currentNode.outputs[0].links : []) || []; + if (outputs.length) { + for (const linkId of outputs) { + const link = app.graph.links[linkId]; + if (!link) + continue; + const node = app.graph.getNodeById(link.target_id); + if (!node) + continue; + const type = node.constructor.type; + if (type === null || type === void 0 ? void 0 : type.includes("Reroute")) { + nodes.push(node); + updateNodes.push(node); + } + else { + const output = (_d = (_c = node.inputs) === null || _c === void 0 ? void 0 : _c[link.target_slot]) !== null && _d !== void 0 ? _d : null; + const nodeOutType = output === null || output === void 0 ? void 0 : output.type; + if (nodeOutType == null) { + console.warn(`[rgthree] Reroute - Connected node ${node.id} does not have type information for ` + + `slot ${link.target_slot}. 
Skipping connection enforcement, but something is odd ` + + `with that node.`); + } + else if (inputType && + inputType !== "*" && + nodeOutType !== "*" && + !isValidConnection(input, output)) { + console.warn(`[rgthree] Reroute - Disconnecting connected node's input (${node.id}.${link.target_slot}) (${node.type}) because its type (${String(nodeOutType)}) does not match the reroute type (${String(inputType)})`); + node.disconnectInput(link.target_slot); + } + else { + outputType = nodeOutType; + outputNode = node; + outputWidgetConfig = null; + outputWidget = null; + if (output === null || output === void 0 ? void 0 : output.widget) { + try { + const config = getWidgetConfig(output); + if (!outputWidgetConfig && config) { + outputWidgetConfig = (_e = config[1]) !== null && _e !== void 0 ? _e : {}; + outputType = config[0]; + if (!outputWidget) { + outputWidget = (_f = outputNode.widgets) === null || _f === void 0 ? void 0 : _f.find((w) => { var _a; return w.name === ((_a = output === null || output === void 0 ? void 0 : output.widget) === null || _a === void 0 ? void 0 : _a.name); }); + } + const merged = mergeIfValid(output, [config[0], outputWidgetConfig]); + if (merged.customConfig) { + outputWidgetConfig = merged.customConfig; + } + } + } + catch (e) { + console.error("[rgthree] Could not propagate widget infor for reroute; maybe ComfyUI updated?"); + outputWidgetConfig = null; + outputWidget = null; + } + } + } + } + } + } + else { + } + } + const displayType = inputType || outputType || "*"; + const color = LGraphCanvas.link_type_colors[displayType]; + for (const node of updateNodes) { + node.outputs[0].type = inputType || "*"; + node.__outputType = displayType; + node.outputs[0].name = (input === null || input === void 0 ? void 0 : input.name) || ""; + node.size = node.computeSize(); + (_h = (_g = node).applyNodeSize) === null || _h === void 0 ? void 0 : _h.call(_g); + for (const l of node.outputs[0].links || []) { + const link = app.graph.links[l]; + if (link) { + link.color = color; + } + } + try { + if (outputWidgetConfig && outputWidget && outputType) { + node.inputs[0].widget = { name: "value" }; + setWidgetConfig(node.inputs[0], [outputType !== null && outputType !== void 0 ? outputType : displayType, outputWidgetConfig], outputWidget); + } + else { + setWidgetConfig(node.inputs[0], null); + } + } + catch (e) { + console.error("[rgthree] Could not set widget config for reroute; maybe ComfyUI updated?"); + outputWidgetConfig = null; + outputWidget = null; + if ((_j = node.inputs[0]) === null || _j === void 0 ? void 0 : _j.widget) { + delete node.inputs[0].widget; + } + } + } + if (inputNode && inputNodeOutputSlot != null) { + const links = inputNode.outputs[inputNodeOutputSlot].links; + for (const l of links || []) { + const link = app.graph.links[l]; + if (link) { + link.color = color; + } + } + } + (_k = inputNode === null || inputNode === void 0 ? void 0 : inputNode.onConnectionsChainChange) === null || _k === void 0 ? void 0 : _k.call(inputNode); + (_l = outputNode === null || outputNode === void 0 ? void 0 : outputNode.onConnectionsChainChange) === null || _l === void 0 ? 
void 0 : _l.call(outputNode); + app.graph.setDirtyCanvas(true, true); + } + setSize(size) { + const oldSize = [...this.size]; + const newSize = [...size]; + super.setSize(newSize); + this.properties["size"] = [...this.size]; + this.stabilizeLayout(oldSize, newSize); + } + stabilizeLayout(oldSize, newSize) { + if (newSize[0] === 10 || newSize[1] === 10) { + const props = this.properties; + props["connections_layout"] = props["connections_layout"] || ["Left", "Right"]; + const layout = props["connections_layout"]; + props["connections_dir"] = props["connections_dir"] || [-1, -1]; + const dir = props["connections_dir"]; + if (oldSize[0] > 10 && newSize[0] === 10) { + dir[0] = LiteGraph.DOWN; + dir[1] = LiteGraph.UP; + if (layout[0] === "Bottom") { + layout[1] = "Top"; + } + else if (layout[1] === "Top") { + layout[0] = "Bottom"; + } + else { + layout[0] = "Top"; + layout[1] = "Bottom"; + dir[0] = LiteGraph.UP; + dir[1] = LiteGraph.DOWN; + } + this.setDirtyCanvas(true, true); + } + else if (oldSize[1] > 10 && newSize[1] === 10) { + dir[0] = LiteGraph.RIGHT; + dir[1] = LiteGraph.LEFT; + if (layout[0] === "Right") { + layout[1] = "Left"; + } + else if (layout[1] === "Left") { + layout[0] = "Right"; + } + else { + layout[0] = "Left"; + layout[1] = "Right"; + dir[0] = LiteGraph.LEFT; + dir[1] = LiteGraph.RIGHT; + } + this.setDirtyCanvas(true, true); + } + } + SERVICE.handleMoveOrResizeNodeMaybeWhileDragging(this); + } + applyNodeSize() { + this.properties["size"] = this.properties["size"] || RerouteNode.size; + this.properties["size"] = [ + Number(this.properties["size"][0]), + Number(this.properties["size"][1]), + ]; + this.size = this.properties["size"]; + app.graph.setDirtyCanvas(true, true); + } + rotate(degrees) { + const w = this.size[0]; + const h = this.size[1]; + this.properties["connections_layout"] = + this.properties["connections_layout"] || this.defaultConnectionsLayout; + const inputDirIndex = LAYOUT_CLOCKWISE.indexOf(this.properties["connections_layout"][0]); + const outputDirIndex = LAYOUT_CLOCKWISE.indexOf(this.properties["connections_layout"][1]); + if (degrees == 90 || degrees === -90) { + if (degrees === -90) { + this.properties["connections_layout"][0] = + LAYOUT_CLOCKWISE[(((inputDirIndex - 1) % 4) + 4) % 4]; + this.properties["connections_layout"][1] = + LAYOUT_CLOCKWISE[(((outputDirIndex - 1) % 4) + 4) % 4]; + } + else { + this.properties["connections_layout"][0] = + LAYOUT_CLOCKWISE[(((inputDirIndex + 1) % 4) + 4) % 4]; + this.properties["connections_layout"][1] = + LAYOUT_CLOCKWISE[(((outputDirIndex + 1) % 4) + 4) % 4]; + } + } + else if (degrees === 180) { + this.properties["connections_layout"][0] = + LAYOUT_CLOCKWISE[(((inputDirIndex + 2) % 4) + 4) % 4]; + this.properties["connections_layout"][1] = + LAYOUT_CLOCKWISE[(((outputDirIndex + 2) % 4) + 4) % 4]; + } + this.setSize([h, w]); + } + manuallyHandleMove(event) { + const shortcut = this.shortcuts.move; + if (shortcut.state) { + const diffX = Math.round((event.clientX - shortcut.initialMousePos[0]) / 10) * 10; + const diffY = Math.round((event.clientY - shortcut.initialMousePos[1]) / 10) * 10; + this.pos[0] = shortcut.initialNodePos[0] + diffX; + this.pos[1] = shortcut.initialNodePos[1] + diffY; + this.setDirtyCanvas(true, true); + SERVICE.handleMoveOrResizeNodeMaybeWhileDragging(this); + } + } + manuallyHandleResize(event) { + const shortcut = this.shortcuts.resize; + if (shortcut.state) { + let diffX = Math.round((event.clientX - shortcut.initialMousePos[0]) / 10) * 10; + let diffY = Math.round((event.clientY - 
shortcut.initialMousePos[1]) / 10) * 10; + diffX *= shortcut.resizeOnSide[0] === LiteGraph.LEFT ? -1 : 1; + diffY *= shortcut.resizeOnSide[1] === LiteGraph.UP ? -1 : 1; + const oldSize = [...this.size]; + this.setSize([ + Math.max(10, shortcut.initialNodeSize[0] + diffX), + Math.max(10, shortcut.initialNodeSize[1] + diffY), + ]); + if (shortcut.resizeOnSide[0] === LiteGraph.LEFT && oldSize[0] > 10) { + this.pos[0] = shortcut.initialNodePos[0] - diffX; + } + if (shortcut.resizeOnSide[1] === LiteGraph.UP && oldSize[1] > 10) { + this.pos[1] = shortcut.initialNodePos[1] - diffY; + } + this.setDirtyCanvas(true, true); + } + } + cycleConnection(ioDir) { + var _a, _b; + const props = this.properties; + props["connections_layout"] = props["connections_layout"] || ["Left", "Right"]; + const propIdx = ioDir == IoDirection.INPUT ? 0 : 1; + const oppositeIdx = propIdx ? 0 : 1; + let currentLayout = props["connections_layout"][propIdx]; + let oppositeLayout = props["connections_layout"][oppositeIdx]; + if (this.size[0] === 10 || this.size[1] === 10) { + props["connections_dir"] = props["connections_dir"] || [-1, -1]; + let currentDir = props["connections_dir"][propIdx]; + const options = this.size[0] === 10 + ? currentLayout === "Bottom" + ? [LiteGraph.DOWN, LiteGraph.RIGHT, LiteGraph.LEFT] + : [LiteGraph.UP, LiteGraph.LEFT, LiteGraph.RIGHT] + : currentLayout === "Right" + ? [LiteGraph.RIGHT, LiteGraph.DOWN, LiteGraph.UP] + : [LiteGraph.LEFT, LiteGraph.UP, LiteGraph.DOWN]; + let idx = options.indexOf(currentDir); + let next = (_a = options[idx + 1]) !== null && _a !== void 0 ? _a : options[0]; + this.properties["connections_dir"][propIdx] = next; + return; + } + let next = currentLayout; + do { + let idx = LAYOUT_CLOCKWISE.indexOf(next); + next = (_b = LAYOUT_CLOCKWISE[idx + 1]) !== null && _b !== void 0 ? _b : LAYOUT_CLOCKWISE[0]; + } while (next === oppositeLayout); + this.properties["connections_layout"][propIdx] = next; + this.setDirtyCanvas(true, true); + } + onMouseMove(event) { + if (this.shortcuts.move.state) { + const shortcut = this.shortcuts.move; + if (shortcut.initialMousePos[0] === -1) { + shortcut.initialMousePos[0] = event.clientX; + shortcut.initialMousePos[1] = event.clientY; + shortcut.initialNodePos[0] = this.pos[0]; + shortcut.initialNodePos[1] = this.pos[1]; + } + this.manuallyHandleMove(event); + } + else if (this.shortcuts.resize.state) { + const shortcut = this.shortcuts.resize; + if (shortcut.initialMousePos[0] === -1) { + shortcut.initialMousePos[0] = event.clientX; + shortcut.initialMousePos[1] = event.clientY; + shortcut.initialNodeSize[0] = this.size[0]; + shortcut.initialNodeSize[1] = this.size[1]; + shortcut.initialNodePos[0] = this.pos[0]; + shortcut.initialNodePos[1] = this.pos[1]; + const canvas = app.canvas; + const offset = canvas.convertEventToCanvasOffset(event); + shortcut.resizeOnSide[0] = this.pos[0] > offset[0] ? LiteGraph.LEFT : LiteGraph.RIGHT; + shortcut.resizeOnSide[1] = this.pos[1] > offset[1] ? LiteGraph.UP : LiteGraph.DOWN; + } + this.manuallyHandleResize(event); + } + } + onKeyDown(event) { + super.onKeyDown(event); + const canvas = app.canvas; + if (CONFIG_FAST_REROUTE_ENABLED) { + for (const [key, shortcut] of Object.entries(this.shortcuts)) { + if (!shortcut.state) { + const keys = KEY_EVENT_SERVICE.areOnlyKeysDown(shortcut.keys); + if (keys) { + shortcut.state = true; + if (key === "rotate") { + this.rotate(90); + } + else if (key.includes("connection")) { + this.cycleConnection(key.includes("input") ? 
IoDirection.INPUT : IoDirection.OUTPUT); + } + if (shortcut.initialMousePos) { + canvas.node_capturing_input = this; + } + } + } + } + } + } + onKeyUp(event) { + super.onKeyUp(event); + const canvas = app.canvas; + if (CONFIG_FAST_REROUTE_ENABLED) { + for (const [key, shortcut] of Object.entries(this.shortcuts)) { + if (shortcut.state) { + const keys = KEY_EVENT_SERVICE.areOnlyKeysDown(shortcut.keys); + if (!keys) { + shortcut.state = false; + if (shortcut.initialMousePos) { + shortcut.initialMousePos = [-1, -1]; + if ((canvas.node_capturing_input = this)) { + canvas.node_capturing_input = null; + } + this.setDirtyCanvas(true, true); + } + } + } + } + } + } + onDeselected() { + var _a; + (_a = super.onDeselected) === null || _a === void 0 ? void 0 : _a.call(this); + const canvas = app.canvas; + for (const [key, shortcut] of Object.entries(this.shortcuts)) { + shortcut.state = false; + if (shortcut.initialMousePos) { + shortcut.initialMousePos = [-1, -1]; + if ((canvas.node_capturing_input = this)) { + canvas.node_capturing_input = null; + } + this.setDirtyCanvas(true, true); + } + } + } + onRemoved() { + var _a; + (_a = super.onRemoved) === null || _a === void 0 ? void 0 : _a.call(this); + setTimeout(() => { + SERVICE.handleRemovedNodeMaybeWhileDragging(this); + }, 32); + } + getHelp() { + return ` +

+        <p>
+          Finally, a comfortable, powerful reroute node with true multi-direction and powerful
+          shortcuts to bring your workflow to the next level.
+        </p>
+        ${!CONFIG_FAST_REROUTE_ENABLED
+          ? `<p>Fast Shortcuts are currently disabled.</p>`
+          : `
+        <ul>
+          <li><p>
+            <code>${CONFIG_KEY_CREATE_WHILE_LINKING}</code> Create a new reroute node while
+            dragging a link, connecting it to the link in place and continuing the link.
+          </p></li>
+          <li><p>
+            <code>${CONFIG_KEY_ROTATE}</code> Rotate the selected reroute node counter-clockwise
+            90 degrees.
+          </p></li>
+          <li><p>
+            <code>${CONFIG_KEY_RESIZE}</code> Resize the selected reroute node from the nearest
+            corner by holding down and moving your mouse.
+          </p></li>
+          <li><p>
+            <code>${CONFIG_KEY_MOVE}</code> Move the selected reroute node by holding down and
+            moving your mouse.
+          </p></li>
+          <li><p>
+            <code>${CONFIG_KEY_CXN_INPUT}</code> Change the input layout/direction of the selected
+            reroute node.
+          </p></li>
+          <li><p>
+            <code>${CONFIG_KEY_CXN_OUTPUT}</code> Change the output layout/direction of the
+            selected reroute node.
+          </p></li>
+        </ul>
+        `}
+        <p>
+          To change, ${!CONFIG_FAST_REROUTE_ENABLED ? "enable" : "disable"} or configure shortcuts,
+          make a copy of <code>/custom_nodes/rgthree-comfy/rgthree_config.json.default</code> to
+          <code>/custom_nodes/rgthree-comfy/rgthree_config.json</code> and configure under
+          <code>nodes > reroute > fast_reroute</code>; a sketch of that section is shown below.
+        </p>
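+        <p>
+          As a rough sketch only (the key names below are the ones this file reads; the key
+          combinations and sizes are placeholders, so check the shipped
+          <code>rgthree_config.json.default</code> for the real defaults), that section looks
+          something like:
+        </p>
+        <pre><code>"nodes": {
+  "reroute": {
+    "default_width": 40,
+    "default_height": 30,
+    "default_resizable": false,
+    "default_layout": ["Left", "Right"],
+    "fast_reroute": {
+      "enabled": true,
+      "key_create_while_dragging_link": "ctrl + shift",
+      "key_rotate": "ctrl + alt + r",
+      "key_resize": "ctrl + alt + s",
+      "key_move": "ctrl + alt + m",
+      "key_connections_input": "ctrl + alt + i",
+      "key_connections_output": "ctrl + alt + o"
+    }
+  }
+}</code></pre>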

+ `; + } +} +RerouteNode.title = NodeTypesString.REROUTE; +RerouteNode.type = NodeTypesString.REROUTE; +RerouteNode.title_mode = LiteGraph.NO_TITLE; +RerouteNode.collapsable = false; +RerouteNode.layout_slot_offset = 5; +RerouteNode.size = configDefaultSize; +addMenuItem(RerouteNode, app, { + name: (node) => { var _a; return `${((_a = node.properties) === null || _a === void 0 ? void 0 : _a["showLabel"]) ? "Hide" : "Show"} Label/Title`; }, + property: "showLabel", + callback: async (node, value) => { + app.graph.setDirtyCanvas(true, true); + }, +}); +addMenuItem(RerouteNode, app, { + name: (node) => `${node.resizable ? "No" : "Allow"} Resizing`, + callback: (node) => { + node.setResizable(!node.resizable); + node.size[0] = Math.max(40, node.size[0]); + node.size[1] = Math.max(30, node.size[1]); + node.applyNodeSize(); + }, +}); +addMenuItem(RerouteNode, app, { + name: "Static Width", + property: "size", + subMenuOptions: (() => { + const options = []; + for (let w = 8; w > 0; w--) { + options.push(`${w * 10}`); + } + return options; + })(), + prepareValue: (value, node) => [Number(value), node.size[1]], + callback: (node) => { + node.setResizable(false); + node.applyNodeSize(); + }, +}); +addMenuItem(RerouteNode, app, { + name: "Static Height", + property: "size", + subMenuOptions: (() => { + const options = []; + for (let w = 8; w > 0; w--) { + options.push(`${w * 10}`); + } + return options; + })(), + prepareValue: (value, node) => [node.size[0], Number(value)], + callback: (node) => { + node.setResizable(false); + node.applyNodeSize(); + }, +}); +addConnectionLayoutSupport(RerouteNode, app, [ + ["Left", "Right"], + ["Left", "Top"], + ["Left", "Bottom"], + ["Right", "Left"], + ["Right", "Top"], + ["Right", "Bottom"], + ["Top", "Left"], + ["Top", "Right"], + ["Top", "Bottom"], + ["Bottom", "Left"], + ["Bottom", "Right"], + ["Bottom", "Top"], +], (node) => { + node.applyNodeSize(); +}); +addMenuItem(RerouteNode, app, { + name: "Rotate", + subMenuOptions: [ + "Rotate 90° Clockwise", + "Rotate 90° Counter-Clockwise", + "Rotate 180°", + null, + "Flip Horizontally", + "Flip Vertically", + ], + callback: (node_, value) => { + const node = node_; + if (value === null || value === void 0 ? void 0 : value.startsWith("Rotate 90° Clockwise")) { + node.rotate(90); + } + else if (value === null || value === void 0 ? void 0 : value.startsWith("Rotate 90° Counter-Clockwise")) { + node.rotate(-90); + } + else if (value === null || value === void 0 ? void 0 : value.startsWith("Rotate 180°")) { + node.rotate(180); + } + else { + const inputDirIndex = LAYOUT_CLOCKWISE.indexOf(node.properties["connections_layout"][0]); + const outputDirIndex = LAYOUT_CLOCKWISE.indexOf(node.properties["connections_layout"][1]); + if (value === null || value === void 0 ? void 0 : value.startsWith("Flip Horizontally")) { + if (["Left", "Right"].includes(node.properties["connections_layout"][0])) { + node.properties["connections_layout"][0] = + LAYOUT_CLOCKWISE[(((inputDirIndex + 2) % 4) + 4) % 4]; + } + if (["Left", "Right"].includes(node.properties["connections_layout"][1])) { + node.properties["connections_layout"][1] = + LAYOUT_CLOCKWISE[(((outputDirIndex + 2) % 4) + 4) % 4]; + } + } + else if (value === null || value === void 0 ? 
void 0 : value.startsWith("Flip Vertically")) { + if (["Top", "Bottom"].includes(node.properties["connections_layout"][0])) { + node.properties["connections_layout"][0] = + LAYOUT_CLOCKWISE[(((inputDirIndex + 2) % 4) + 4) % 4]; + } + if (["Top", "Bottom"].includes(node.properties["connections_layout"][1])) { + node.properties["connections_layout"][1] = + LAYOUT_CLOCKWISE[(((outputDirIndex + 2) % 4) + 4) % 4]; + } + } + } + }, +}); +addMenuItem(RerouteNode, app, { + name: "Clone New Reroute...", + subMenuOptions: ["Before", "After"], + callback: async (node, value) => { + const clone = node.clone(); + const pos = [...node.pos]; + if (value === "Before") { + clone.pos = [pos[0] - 20, pos[1] - 20]; + app.graph.add(clone); + await wait(); + const inputLinks = getSlotLinks(node.inputs[0]); + for (const inputLink of inputLinks) { + const link = inputLink.link; + const linkedNode = app.graph.getNodeById(link.origin_id); + if (linkedNode) { + linkedNode.connect(0, clone, 0); + } + } + clone.connect(0, node, 0); + } + else { + clone.pos = [pos[0] + 20, pos[1] + 20]; + app.graph.add(clone); + await wait(); + const outputLinks = getSlotLinks(node.outputs[0]); + node.connect(0, clone, 0); + for (const outputLink of outputLinks) { + const link = outputLink.link; + const linkedNode = app.graph.getNodeById(link.target_id); + if (linkedNode) { + clone.connect(0, linkedNode, link.target_slot); + } + } + } + }, +}); +app.registerExtension({ + name: "rgthree.Reroute", + registerCustomNodes() { + RerouteNode.setUp(); + }, +}); diff --git a/rgthree-comfy/web/comfyui/rgthree.css b/rgthree-comfy/web/comfyui/rgthree.css new file mode 100644 index 0000000000000000000000000000000000000000..e677579e9087bc473d7003f5b9c43c1b15de93b4 --- /dev/null +++ b/rgthree-comfy/web/comfyui/rgthree.css @@ -0,0 +1,615 @@ +@charset "UTF-8"; +.rgthree-top-messages-container { + position: fixed; + z-index: 9999; + top: 0; + left: 0; + width: 100%; + height: 0; + display: flex; + flex-direction: column; + align-items: center; + justify-content: start; +} + +.rgthree-top-messages-container > div { + position: relative; + height: fit-content; + padding: 4px; + margin-top: -100px; /* re-set by JS */ + opacity: 0; + transition: all 0.33s ease-in-out; + z-index: 3; +} + +.rgthree-top-messages-container > div:last-child { + z-index: 2; +} + +.rgthree-top-messages-container > div:not(.-show) { + z-index: 1; +} + +.rgthree-top-messages-container > div.-show { + opacity: 1; + margin-top: 0px !important; +} + +.rgthree-top-messages-container > div.-show { + opacity: 1; + transform: translateY(0%); +} + +.rgthree-top-messages-container > div > div { + position: relative; + background: #353535; + color: #fff; + display: flex; + flex-direction: row; + align-items: center; + justify-content: center; + height: fit-content; + box-shadow: 0 0 10px rgba(0, 0, 0, 0.88); + padding: 6px 12px; + border-radius: 4px; + font-family: Arial, sans-serif; + font-size: 14px; +} + +.rgthree-top-messages-container > div > div > span { + display: flex; + flex-direction: row; + align-items: center; + justify-content: center; +} + +.rgthree-top-messages-container > div > div > span svg { + width: 20px; + height: auto; + margin-right: 8px; +} + +.rgthree-top-messages-container > div > div > span svg.icon-checkmark { + fill: #2e9720; +} + +.rgthree-top-messages-container [type=warn]::before, +.rgthree-top-messages-container [type=success]::before { + content: "⚠️"; + display: inline-block; + flex: 0 0 auto; + font-size: 18px; + margin-right: 4px; + line-height: 1; +} + 
+.rgthree-top-messages-container [type=success]::before { + content: "🎉"; +} + +.rgthree-top-messages-container a { + cursor: pointer; + text-decoration: underline; + color: #fc0; + margin-left: 4px; + display: inline-block; + line-height: 1; +} + +.rgthree-top-messages-container a:hover { + color: #fc0; + text-decoration: none; +} + +/* Fix node selector being crazy long b/c of array types. */ +.litegraph.litesearchbox input, +.litegraph.litesearchbox select { + max-width: 250px; +} + +/* There's no reason for this z-index to be so high. It layers on top of things it shouldn't, + (like pythongssss' image gallery, the properties panel, etc.) */ +.comfy-multiline-input { + z-index: 1 !important; +} + +.comfy-multiline-input:focus { + z-index: 2 !important; +} + +.litegraph .dialog { + z-index: 3 !important; /* This is set to 1, but goes under the multi-line inputs, so bump it. */ +} + +:not(#fakeid) .rgthree-button-reset { + position: relative; + appearance: none; + cursor: pointer; + border: 0; + background: transparent; + color: inherit; + padding: 0; + margin: 0; +} + +:not(#fakeid) .rgthree-button { + --padding-top: 7px; + --padding-bottom: 9px; + --padding-x: 16px; + position: relative; + cursor: pointer; + border: 0; + border-radius: 0.25rem; + background: rgba(0, 0, 0, 0.5); + color: white; + font-family: system-ui, sans-serif; + font-size: 1rem; + line-height: 1; + white-space: nowrap; + text-decoration: none; + margin: 0.25rem; + box-shadow: 0px 0px 2px rgb(0, 0, 0); + background: #212121; + transition: all 0.1s ease-in-out; + padding: var(--padding-top) var(--padding-x) var(--padding-bottom); + display: inline-flex; + flex-direction: row; + align-items: center; + justify-content: center; +} +:not(#fakeid) .rgthree-button::before, :not(#fakeid) .rgthree-button::after { + content: ""; + display: block; + position: absolute; + border-radius: 0.25rem; + left: 0; + top: 0; + width: 100%; + height: 100%; + box-shadow: inset 1px 1px 0px rgba(255, 255, 255, 0.12), inset -1px -1px 0px rgba(0, 0, 0, 0.75); + background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)); + mix-blend-mode: screen; +} +:not(#fakeid) .rgthree-button::after { + mix-blend-mode: multiply; +} +:not(#fakeid) .rgthree-button:hover { + background: #303030; +} +:not(#fakeid) .rgthree-button:active { + box-shadow: 0px 0px 0px rgba(0, 0, 0, 0); + background: #121212; + padding: calc(var(--padding-top) + 1px) calc(var(--padding-x) - 1px) calc(var(--padding-bottom) - 1px) calc(var(--padding-x) + 1px); +} +:not(#fakeid) .rgthree-button:active::before, :not(#fakeid) .rgthree-button:active::after { + box-shadow: 1px 1px 0px rgba(255, 255, 255, 0.15), inset 1px 1px 0px rgba(0, 0, 0, 0.5), inset 1px 3px 5px rgba(0, 0, 0, 0.33); +} +:not(#fakeid) .rgthree-button.-blue { + background: #346599 !important; +} +:not(#fakeid) .rgthree-button.-blue:hover { + background: #3b77b8 !important; +} +:not(#fakeid) .rgthree-button.-blue:active { + background: #1d5086 !important; +} +:not(#fakeid) .rgthree-button.-green { + background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)), #14580b; +} +:not(#fakeid) .rgthree-button.-green:hover { + background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)), #1a6d0f; +} +:not(#fakeid) .rgthree-button.-green:active { + background: linear-gradient(to bottom, rgba(0, 0, 0, 0.15), rgba(255, 255, 255, 0.06)), #0f3f09; +} +:not(#fakeid) .rgthree-button[disabled] { + box-shadow: none; + background: #666 !important; + color: 
#aaa; + pointer-events: none; +} +:not(#fakeid) .rgthree-button[disabled]::before, :not(#fakeid) .rgthree-button[disabled]::after { + display: none; +} + +.rgthree-dialog { + outline: 0; + border: 0; + border-radius: 6px; + background: #414141; + color: #fff; + box-shadow: inset 1px 1px 0px rgba(255, 255, 255, 0.05), inset -1px -1px 0px rgba(0, 0, 0, 0.5), 2px 2px 20px rgb(0, 0, 0); + max-width: 800px; + box-sizing: border-box; + font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif; + font-size: 1rem; + padding: 0; + max-height: calc(100% - 32px); +} +.rgthree-dialog *, .rgthree-dialog *::before, .rgthree-dialog *::after { + box-sizing: inherit; +} + +.rgthree-dialog-container > * { + padding: 8px 16px; +} +.rgthree-dialog-container > *:first-child { + padding-top: 16px; +} +.rgthree-dialog-container > *:last-child { + padding-bottom: 16px; +} + +.rgthree-dialog.-iconed::after { + content: ""; + font-size: 276px; + position: absolute; + right: 0px; + bottom: 0px; + opacity: 0.15; + display: block; + width: 237px; + overflow: hidden; + height: 186px; + line-height: 1; + pointer-events: none; + z-index: -1; +} + +.rgthree-dialog.-iconed.-help::after { + content: "🛟"; +} + +.rgthree-dialog.-iconed.-settings::after { + content: "⚙️"; +} + +@media (max-width: 832px) { + .rgthree-dialog { + max-width: calc(100% - 32px); + } +} +.rgthree-dialog-container-title { + display: flex; + flex-direction: row; + align-items: center; + justify-content: start; +} + +.rgthree-dialog-container-title > svg:first-child { + width: 36px; + height: 36px; + margin-right: 16px; +} + +.rgthree-dialog-container-title h2 { + font-size: 1.375rem; + margin: 0; + font-weight: bold; +} + +.rgthree-dialog-container-title h2 small { + font-size: 0.8125rem; + font-weight: normal; + opacity: 0.75; +} + +.rgthree-dialog-container-content { + overflow: auto; + max-height: calc(100vh - 200px); /* Arbitrary height to copensate for margin, title, and footer.*/ +} + +.rgthree-dialog-container-content p { + font-size: 0.8125rem; + margin-top: 0; +} + +.rgthree-dialog-container-content ul li p { + margin-bottom: 4px; +} + +.rgthree-dialog-container-content ul li p + p { + margin-top: 0.5em; +} + +.rgthree-dialog-container-content ul li ul { + margin-top: 0.5em; + margin-bottom: 1em; +} + +.rgthree-dialog-container-content p code { + display: inline-block; + padding: 2px 4px; + margin: 0px 2px; + border: 1px solid rgba(255, 255, 255, 0.25); + border-radius: 3px; + background: rgba(255, 255, 255, 0.1); +} + +.rgthree-dialog-container-footer { + display: flex; + align-items: center; + justify-content: center; +} + +body.rgthree-dialog-open > *:not(.rgthree-dialog):not(.rgthree-top-messages-container) { + filter: blur(5px); +} + +.rgthree-menu { + list-style: none; + padding: 0; + margin: 0; + position: fixed; + z-index: 999999; + pointer-events: none; + opacity: 0; + transition: opacity 0.08s ease-in-out; + color: #dde; + background-color: #111; + font-size: 12px; + box-shadow: 0 0 10px black !important; +} +.rgthree-menu > li { + position: relative; + padding: 4px 6px; + z-index: 9999; + white-space: nowrap; +} +.rgthree-menu > li[role=button] { + background-color: var(--comfy-menu-bg) !important; + color: var(--input-text); + cursor: pointer; +} +.rgthree-menu > li[role=button]:hover { + filter: brightness(155%); +} +.rgthree-menu[state^=measuring] { + display: block; + opacity: 0; +} +.rgthree-menu[state=open] { + display: block; + opacity: 1; + pointer-events: all; +} + +.rgthree-top-menu { + box-sizing: border-box; + 
white-space: nowrap; + background: var(--content-bg); + color: var(--content-fg); + display: flex; + flex-direction: column; +} +.rgthree-top-menu * { + box-sizing: inherit; +} +.rgthree-top-menu menu { + list-style: none; + padding: 0; + margin: 0; +} +.rgthree-top-menu menu > li:not(#fakeid) { + list-style: none; + padding: 0; + margin: 0; +} +.rgthree-top-menu menu > li:not(#fakeid) > button { + cursor: pointer; + padding: 8px 12px 8px 8px; + width: 100%; + text-align: start; + display: flex; + flex-direction: row; + align-items: center; + justify-content: start; +} +.rgthree-top-menu menu > li:not(#fakeid) > button:hover { + background-color: var(--comfy-input-bg); +} +.rgthree-top-menu menu > li:not(#fakeid) > button svg { + height: 16px; + width: auto; + margin-inline-end: 0.6em; +} +.rgthree-top-menu menu > li:not(#fakeid) > button svg.github-star { + fill: rgb(227, 179, 65); +} +.rgthree-top-menu menu > li:not(#fakeid).rgthree-message { + min-height: 32px; +} +.rgthree-top-menu menu > li:not(#fakeid).rgthree-message > span { + padding: 8px 12px; + display: block; + width: 100%; + text-align: center; + font-style: italic; + font-size: 12px; +} + +.rgthree-dialog.-settings { + width: 100%; +} + +.rgthree-dialog.-settings fieldset { + border: 1px solid rgba(255, 255, 255, 0.25); + padding: 0 12px 8px; + margin-bottom: 16px; +} + +.rgthree-dialog.-settings fieldset > legend { + margin-left: 8px; + padding: 0 8px; + opacity: 0.5; +} + +.rgthree-dialog.-settings .formrow { + display: flex; + flex-direction: column; +} + +.rgthree-dialog.-settings .formrow + .formrow { + border-top: 1px solid rgba(255, 255, 255, 0.25); +} + +.rgthree-dialog.-settings .fieldrow { + display: flex; + flex-direction: row; +} + +.rgthree-dialog.-settings .fieldrow > label { + flex: 1 1 auto; + user-select: none; + padding: 8px 12px 12px; +} + +.rgthree-dialog.-settings .fieldrow > label span { + font-weight: bold; +} + +.rgthree-dialog.-settings .fieldrow > label small { + display: block; + margin-top: 4px; + font-size: 0.6875rem; + opacity: 0.75; + padding-left: 16px; +} + +.rgthree-dialog.-settings .fieldrow ~ .fieldrow { + font-size: 0.9rem; + border-top: 1px dotted rgba(255, 255, 255, 0.25); +} + +.rgthree-dialog.-settings .fieldrow ~ .fieldrow label { + padding-left: 28px; +} + +.rgthree-dialog.-settings .fieldrow:first-child:not(.-checked) ~ .fieldrow { + display: none; +} + +.rgthree-dialog.-settings .fieldrow:hover { + background: rgba(255, 255, 255, 0.1); +} + +.rgthree-dialog.-settings .fieldrow ~ .fieldrow span { + font-weight: normal; +} + +.rgthree-dialog.-settings .fieldrow > .fieldrow-value { + display: flex; + align-items: center; + justify-content: end; + flex: 0 0 auto; + width: 50%; + max-width: 230px; +} + +.rgthree-dialog.-settings .fieldrow.-type-boolean > .fieldrow-value { + max-width: 64px; +} + +.rgthree-dialog.-settings .fieldrow.-type-number input { + width: 48px; + text-align: right; +} + +.rgthree-dialog.-settings .fieldrow input[type=checkbox] { + width: 24px; + height: 24px; + cursor: pointer; +} + +.rgthree-comfyui-settings-row div { + display: flex; + flex-direction: row; + align-items: center; + justify-content: end; +} + +.rgthree-comfyui-settings-row div svg { + width: 36px; + height: 36px; + margin-right: 16px; +} + +.litegraph.litecontextmenu .litemenu-title .rgthree-contextmenu-title-rgthree-comfy, +.litegraph.litecontextmenu .litemenu-entry.rgthree-contextmenu-item { + display: flex; + flex-direction: row; + align-items: center; + justify-content: start; +} + 
+.litegraph.litecontextmenu .litemenu-title .rgthree-contextmenu-title-rgthree-comfy svg, +.litegraph.litecontextmenu .litemenu-entry.rgthree-contextmenu-item svg { + fill: currentColor; + width: auto; + height: 16px; + margin-right: 6px; +} + +.litegraph.litecontextmenu .litemenu-entry.rgthree-contextmenu-item svg.github-star { + fill: rgb(227, 179, 65); +} + +.litegraph.litecontextmenu .litemenu-title .rgthree-contextmenu-title-rgthree-comfy, +.litegraph.litecontextmenu .litemenu-entry.rgthree-contextmenu-label { + color: #dde; + background-color: #212121 !important; + margin: 0; + padding: 2px; + cursor: default; + opacity: 1; + padding: 4px; + font-weight: bold; +} + +.litegraph.litecontextmenu .litemenu-title .rgthree-contextmenu-title-rgthree-comfy { + font-size: 1.1em; + color: #fff; + background-color: #090909 !important; + justify-content: center; + padding: 4px 8px; +} + +rgthree-progress-bar { + display: block; + position: relative; + z-index: 999; + top: 0; + left: 0; + height: 14px; + font-size: 10px; + width: 100%; + overflow: hidden; + box-shadow: 0px 0px 3px rgba(0, 0, 0, 0.25); + box-shadow: inset 0px -1px 0px rgba(0, 0, 0, 0.25), 0px 1px 0px rgba(255, 255, 255, 0.125); +} + +* ~ rgthree-progress-bar, +.comfyui-body-bottom rgthree-progress-bar { + box-shadow: 0px -1px 0px rgb(0, 0, 0), inset 0px 1px 0px rgba(255, 255, 255, 0.15), inset 0px -1px 0px rgba(0, 0, 0, 0.25), 0px 1px 0px rgba(255, 255, 255, 0.125); +} + +body:not([style*=grid]) rgthree-progress-bar { + position: fixed; + top: 0px; + bottom: auto; +} +body:not([style*=grid]) rgthree-progress-bar.rgthree-pos-bottom { + top: auto; + bottom: 0px; +} + +.rgthree-debug-keydowns { + display: block; + position: fixed; + z-index: 999; + top: 3px; + right: 8px; + font-size: 10px; + color: #fff; + font-family: sans-serif; +} diff --git a/rgthree-comfy/web/comfyui/rgthree.js b/rgthree-comfy/web/comfyui/rgthree.js new file mode 100644 index 0000000000000000000000000000000000000000..e2ec2319340e4467343af5ad635e8204a3cc5d26 --- /dev/null +++ b/rgthree-comfy/web/comfyui/rgthree.js @@ -0,0 +1,705 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; +import { SERVICE as CONFIG_SERVICE } from "./services/config_service.js"; +import { fixBadLinks } from "../../rgthree/common/link_fixer.js"; +import { injectCss, wait } from "../../rgthree/common/shared_utils.js"; +import { replaceNode, waitForCanvas, waitForGraph } from "./utils.js"; +import { NodeTypesString, addRgthree, getNodeTypeStrings } from "./constants.js"; +import { RgthreeProgressBar } from "../../rgthree/common/progress_bar.js"; +import { RgthreeConfigDialog } from "./config.js"; +import { iconGear, iconNode, iconReplace, iconStarFilled, logoRgthree, } from "../../rgthree/common/media/svgs.js"; +import { query, queryOne } from "../../rgthree/common/utils_dom.js"; +export var LogLevel; +(function (LogLevel) { + LogLevel[LogLevel["IMPORTANT"] = 1] = "IMPORTANT"; + LogLevel[LogLevel["ERROR"] = 2] = "ERROR"; + LogLevel[LogLevel["WARN"] = 3] = "WARN"; + LogLevel[LogLevel["INFO"] = 4] = "INFO"; + LogLevel[LogLevel["DEBUG"] = 5] = "DEBUG"; + LogLevel[LogLevel["DEV"] = 6] = "DEV"; +})(LogLevel || (LogLevel = {})); +const LogLevelKeyToLogLevel = { + IMPORTANT: LogLevel.IMPORTANT, + ERROR: LogLevel.ERROR, + WARN: LogLevel.WARN, + INFO: LogLevel.INFO, + DEBUG: LogLevel.DEBUG, + DEV: LogLevel.DEV, +}; +const LogLevelToMethod = { + [LogLevel.IMPORTANT]: "log", + [LogLevel.ERROR]: "error", + [LogLevel.WARN]: "warn", + [LogLevel.INFO]: "info", + 
[LogLevel.DEBUG]: "log", + [LogLevel.DEV]: "log", +}; +const LogLevelToCSS = { + [LogLevel.IMPORTANT]: "font-weight: bold; color: blue;", + [LogLevel.ERROR]: "", + [LogLevel.WARN]: "", + [LogLevel.INFO]: "font-style: italic; color: blue;", + [LogLevel.DEBUG]: "font-style: italic; color: #444;", + [LogLevel.DEV]: "color: #004b68;", +}; +let GLOBAL_LOG_LEVEL = LogLevel.ERROR; +const INVOKE_EXTENSIONS_BLOCKLIST = [ + { + name: "Comfy.WidgetInputs", + reason: "Major conflict with rgthree-comfy nodes' inputs causing instability and " + + "repeated link disconnections.", + }, + { + name: "efficiency.widgethider", + reason: "Overrides value getter before widget getter is prepared. Can be lifted if/when " + + "https://github.com/jags111/efficiency-nodes-comfyui/pull/203 is pulled.", + }, +]; +class Logger { + log(level, message, ...args) { + var _a; + const [n, v] = this.logParts(level, message, ...args); + (_a = console[n]) === null || _a === void 0 ? void 0 : _a.call(console, ...v); + } + logParts(level, message, ...args) { + if (level <= GLOBAL_LOG_LEVEL) { + const css = LogLevelToCSS[level] || ""; + if (level === LogLevel.DEV) { + message = `🔧 ${message}`; + } + return [LogLevelToMethod[level], [`%c${message}`, css, ...args]]; + } + return ["none", []]; + } +} +class LogSession { + constructor(name) { + this.name = name; + this.logger = new Logger(); + this.logsCache = {}; + } + logParts(level, message, ...args) { + message = `${this.name || ""}${message ? " " + message : ""}`; + return this.logger.logParts(level, message, ...args); + } + logPartsOnceForTime(level, time, message, ...args) { + message = `${this.name || ""}${message ? " " + message : ""}`; + const cacheKey = `${level}:${message}`; + const cacheEntry = this.logsCache[cacheKey]; + const now = +new Date(); + if (cacheEntry && cacheEntry.lastShownTime + time > now) { + return ["none", []]; + } + const parts = this.logger.logParts(level, message, ...args); + if (console[parts[0]]) { + this.logsCache[cacheKey] = this.logsCache[cacheKey] || {}; + this.logsCache[cacheKey].lastShownTime = now; + } + return parts; + } + debugParts(message, ...args) { + return this.logParts(LogLevel.DEBUG, message, ...args); + } + infoParts(message, ...args) { + return this.logParts(LogLevel.INFO, message, ...args); + } + warnParts(message, ...args) { + return this.logParts(LogLevel.WARN, message, ...args); + } + newSession(name) { + return new LogSession(`${this.name}${name}`); + } +} +class Rgthree extends EventTarget { + constructor() { + var _a, _b, _c, _d; + super(); + this.api = api; + this.settingsDialog = null; + this.progressBarEl = null; + this.queueNodeIds = null; + this.logger = new LogSession("[rgthree]"); + this.monitorBadLinksAlerted = false; + this.monitorLinkTimeout = null; + this.processingQueue = false; + this.loadingApiJson = false; + this.replacingReroute = null; + this.processingMouseDown = false; + this.processingMouseUp = false; + this.processingMouseMove = false; + this.lastAdjustedMouseEvent = null; + this.canvasCurrentlyCopyingToClipboard = false; + this.canvasCurrentlyCopyingToClipboardWithMultipleNodes = false; + this.initialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff = null; + this.elDebugKeydowns = null; + this.isMac = !!(((_a = navigator.platform) === null || _a === void 0 ? void 0 : _a.toLocaleUpperCase().startsWith("MAC")) || + ((_c = (_b = navigator.userAgentData) === null || _b === void 0 ? void 0 : _b.platform) === null || _c === void 0 ? 
void 0 : _c.toLocaleUpperCase().startsWith("MAC"))); + const logLevel = (_d = LogLevelKeyToLogLevel[CONFIG_SERVICE.getConfigValue("log_level")]) !== null && _d !== void 0 ? _d : GLOBAL_LOG_LEVEL; + this.setLogLevel(logLevel); + this.initializeGraphAndCanvasHooks(); + this.initializeComfyUIHooks(); + this.initializeContextMenu(); + this.rgthreeCssPromise = injectCss("extensions/rgthree-comfy/rgthree.css"); + this.initializeProgressBar(); + CONFIG_SERVICE.addEventListener("config-change", ((e) => { + var _a, _b; + if ((_b = (_a = e.detail) === null || _a === void 0 ? void 0 : _a.key) === null || _b === void 0 ? void 0 : _b.includes("features.progress_bar")) { + this.initializeProgressBar(); + } + })); + } + async initializeProgressBar() { + var _a; + if (CONFIG_SERVICE.getConfigValue("features.progress_bar.enabled")) { + await this.rgthreeCssPromise; + if (!this.progressBarEl) { + this.progressBarEl = RgthreeProgressBar.create(); + this.progressBarEl.setAttribute("title", "Progress Bar by rgthree. right-click for rgthree menu."); + this.progressBarEl.addEventListener("contextmenu", async (e) => { + e.stopPropagation(); + e.preventDefault(); + }); + this.progressBarEl.addEventListener("pointerdown", async (e) => { + var _a; + LiteGraph.closeAllContextMenus(); + if (e.button == 2) { + const canvas = await waitForCanvas(); + new LiteGraph.ContextMenu(this.getRgthreeContextMenuItems(), { + title: `
${logoRgthree} rgthree-comfy
`, + left: e.clientX, + top: 5, + }, canvas.getCanvasWindow()); + return; + } + if (e.button == 0) { + const nodeId = (_a = this.progressBarEl) === null || _a === void 0 ? void 0 : _a.currentNodeId; + if (nodeId) { + const [canvas, graph] = await Promise.all([waitForCanvas(), waitForGraph()]); + const node = graph.getNodeById(Number(nodeId)); + if (node) { + canvas.centerOnNode(node); + e.stopPropagation(); + e.preventDefault(); + } + } + return; + } + }); + } + const isUpdatedComfyBodyClasses = !!queryOne(".comfyui-body-top"); + const position = CONFIG_SERVICE.getConfigValue("features.progress_bar.position"); + this.progressBarEl.classList.toggle("rgthree-pos-bottom", position === "bottom"); + if (isUpdatedComfyBodyClasses) { + if (position === "bottom") { + queryOne(".comfyui-body-bottom").appendChild(this.progressBarEl); + } + else { + queryOne(".comfyui-body-top").appendChild(this.progressBarEl); + } + } + else { + document.body.appendChild(this.progressBarEl); + } + const height = CONFIG_SERVICE.getConfigValue("features.progress_bar.height") || 14; + this.progressBarEl.style.height = `${height}px`; + const fontSize = Math.max(10, Number(height) - 10); + this.progressBarEl.style.fontSize = `${fontSize}px`; + this.progressBarEl.style.fontWeight = fontSize <= 12 ? "bold" : "normal"; + } + else { + (_a = this.progressBarEl) === null || _a === void 0 ? void 0 : _a.remove(); + } + } + async initializeGraphAndCanvasHooks() { + const rgthree = this; + const graphSerialize = LGraph.prototype.serialize; + LGraph.prototype.serialize = function () { + const response = graphSerialize.apply(this, [...arguments]); + rgthree.initialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff = response; + return response; + }; + const processMouseDown = LGraphCanvas.prototype.processMouseDown; + LGraphCanvas.prototype.processMouseDown = function (e) { + rgthree.processingMouseDown = true; + const returnVal = processMouseDown.apply(this, [...arguments]); + rgthree.dispatchCustomEvent("on-process-mouse-down", { originalEvent: e }); + rgthree.processingMouseDown = false; + return returnVal; + }; + const adjustMouseEvent = LGraphCanvas.prototype.adjustMouseEvent; + LGraphCanvas.prototype.adjustMouseEvent = function (e) { + adjustMouseEvent.apply(this, [...arguments]); + rgthree.lastAdjustedMouseEvent = e; + }; + const copyToClipboard = LGraphCanvas.prototype.copyToClipboard; + LGraphCanvas.prototype.copyToClipboard = function (nodes) { + rgthree.canvasCurrentlyCopyingToClipboard = true; + rgthree.canvasCurrentlyCopyingToClipboardWithMultipleNodes = + Object.values(nodes || this.selected_nodes || []).length > 1; + copyToClipboard.apply(this, [...arguments]); + rgthree.canvasCurrentlyCopyingToClipboard = false; + rgthree.canvasCurrentlyCopyingToClipboardWithMultipleNodes = false; + }; + const onGroupAdd = LGraphCanvas.onGroupAdd; + LGraphCanvas.onGroupAdd = function (...args) { + const graph = app.graph; + onGroupAdd.apply(this, [...args]); + LGraphCanvas.onShowPropertyEditor({}, null, null, null, graph._groups[graph._groups.length - 1]); + }; + } + async invokeExtensionsAsync(method, ...args) { + var _a; + const comfyapp = app; + if (CONFIG_SERVICE.getConfigValue("features.invoke_extensions_async.node_created") === false) { + const [m, a] = this.logParts(LogLevel.INFO, `Skipping invokeExtensionsAsync for applicable rgthree-comfy nodes`); + (_a = console[m]) === null || _a === void 0 ? 
void 0 : _a.call(console, ...a); + return Promise.resolve(); + } + return await Promise.all(comfyapp.extensions.map(async (ext) => { + var _a, _b; + if (ext === null || ext === void 0 ? void 0 : ext[method]) { + try { + const blocked = INVOKE_EXTENSIONS_BLOCKLIST.find((block) => ext.name.toLowerCase().startsWith(block.name.toLowerCase())); + if (blocked) { + const [n, v] = this.logger.logPartsOnceForTime(LogLevel.WARN, 5000, `Blocked extension '${ext.name}' method '${method}' for rgthree-nodes because: ${blocked.reason}`); + (_a = console[n]) === null || _a === void 0 ? void 0 : _a.call(console, ...v); + return Promise.resolve(); + } + return await ext[method](...args, comfyapp); + } + catch (error) { + const [n, v] = this.logParts(LogLevel.ERROR, `Error calling extension '${ext.name}' method '${method}' for rgthree-node.`, { error }, { extension: ext }, { args }); + (_b = console[n]) === null || _b === void 0 ? void 0 : _b.call(console, ...v); + } + } + })); + } + dispatchCustomEvent(event, detail) { + if (detail != null) { + return this.dispatchEvent(new CustomEvent(event, { detail })); + } + return this.dispatchEvent(new CustomEvent(event)); + } + async initializeContextMenu() { + const that = this; + setTimeout(async () => { + const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; + LGraphCanvas.prototype.getCanvasMenuOptions = function (...args) { + let existingOptions = getCanvasMenuOptions.apply(this, [...args]); + const options = []; + options.push(null); + options.push(null); + options.push(null); + options.push({ + content: logoRgthree + `rgthree-comfy`, + className: "rgthree-contextmenu-item rgthree-contextmenu-main-item-rgthree-comfy", + submenu: { + options: that.getRgthreeContextMenuItems(), + }, + }); + options.push(null); + options.push(null); + let idx = null; + idx = idx || existingOptions.findIndex((o) => { var _a, _b; return (_b = (_a = o === null || o === void 0 ? void 0 : o.content) === null || _a === void 0 ? void 0 : _a.startsWith) === null || _b === void 0 ? void 0 : _b.call(_a, "Queue Group"); }) + 1; + idx = + idx || existingOptions.findIndex((o) => { var _a, _b; return (_b = (_a = o === null || o === void 0 ? void 0 : o.content) === null || _a === void 0 ? void 0 : _a.startsWith) === null || _b === void 0 ? void 0 : _b.call(_a, "Queue Selected"); }) + 1; + idx = idx || existingOptions.findIndex((o) => { var _a, _b; return (_b = (_a = o === null || o === void 0 ? void 0 : o.content) === null || _a === void 0 ? void 0 : _a.startsWith) === null || _b === void 0 ? void 0 : _b.call(_a, "Convert to Group"); }); + idx = idx || existingOptions.findIndex((o) => { var _a, _b; return (_b = (_a = o === null || o === void 0 ? void 0 : o.content) === null || _a === void 0 ? void 0 : _a.startsWith) === null || _b === void 0 ? 
void 0 : _b.call(_a, "Arrange ("); }); + idx = idx || existingOptions.findIndex((o) => !o) + 1; + idx = idx || 3; + existingOptions.splice(idx, 0, ...options); + for (let i = existingOptions.length; i > 0; i--) { + if (existingOptions[i] === null && existingOptions[i + 1] === null) { + existingOptions.splice(i, 1); + } + } + return existingOptions; + }; + }, 1016); + } + getRgthreeContextMenuItems() { + const [canvas, graph] = [app.canvas, app.graph]; + const selectedNodes = Object.values(canvas.selected_nodes || {}); + let rerouteNodes = []; + if (selectedNodes.length) { + rerouteNodes = selectedNodes.filter((n) => n.type === "Reroute"); + } + else { + rerouteNodes = graph._nodes.filter((n) => n.type == "Reroute"); + } + const rerouteLabel = selectedNodes.length ? "selected" : "all"; + const showBookmarks = CONFIG_SERVICE.getFeatureValue("menu_bookmarks.enabled"); + const bookmarkMenuItems = showBookmarks ? getBookmarks() : []; + return [ + { + content: "Nodes", + disabled: true, + className: "rgthree-contextmenu-item rgthree-contextmenu-label", + }, + { + content: iconNode + "All", + className: "rgthree-contextmenu-item", + has_submenu: true, + submenu: { + options: getNodeTypeStrings(), + callback: (value, options, event) => { + const node = LiteGraph.createNode(addRgthree(value)); + node.pos = [ + rgthree.lastAdjustedMouseEvent.canvasX, + rgthree.lastAdjustedMouseEvent.canvasY, + ]; + canvas.graph.add(node); + canvas.selectNode(node); + app.graph.setDirtyCanvas(true, true); + }, + extra: { rgthree_doNotNest: true }, + }, + }, + { + content: "Actions", + disabled: true, + className: "rgthree-contextmenu-item rgthree-contextmenu-label", + }, + { + content: iconGear + "Settings (rgthree-comfy)", + disabled: !!this.settingsDialog, + className: "rgthree-contextmenu-item", + callback: (...args) => { + this.settingsDialog = new RgthreeConfigDialog().show(); + this.settingsDialog.addEventListener("close", (e) => { + this.settingsDialog = null; + }); + }, + }, + { + content: iconReplace + ` Convert ${rerouteLabel} Reroutes`, + disabled: !rerouteNodes.length, + className: "rgthree-contextmenu-item", + callback: (...args) => { + const msg = `Convert ${rerouteLabel} ComfyUI Reroutes to Reroute (rgthree) nodes? \n` + + `(First save a copy of your workflow & check reroute connections afterwards)`; + if (!window.confirm(msg)) { + return; + } + (async () => { + for (const node of [...rerouteNodes]) { + if (node.type == "Reroute") { + this.replacingReroute = node.id; + await replaceNode(node, NodeTypesString.REROUTE); + this.replacingReroute = null; + } + } + })(); + }, + }, + ...bookmarkMenuItems, + { + content: "More...", + disabled: true, + className: "rgthree-contextmenu-item rgthree-contextmenu-label", + }, + { + content: iconStarFilled + "Star on Github", + className: "rgthree-contextmenu-item rgthree-contextmenu-github", + callback: (...args) => { + window.open("https://github.com/rgthree/rgthree-comfy", "_blank"); + }, + }, + ]; + } + async queueOutputNodes(nodeIds) { + var _a; + try { + this.queueNodeIds = nodeIds; + await app.queuePrompt(); + } + catch (e) { + const [n, v] = this.logParts(LogLevel.ERROR, `There was an error queuing nodes ${nodeIds}`, e); + (_a = console[n]) === null || _a === void 0 ? 
void 0 : _a.call(console, ...v); + } + finally { + this.queueNodeIds = null; + } + } + recursiveAddNodes(nodeId, oldOutput, newOutput) { + let currentId = nodeId; + let currentNode = oldOutput[currentId]; + if (newOutput[currentId] == null) { + newOutput[currentId] = currentNode; + for (const inputValue of Object.values(currentNode.inputs || [])) { + if (Array.isArray(inputValue)) { + this.recursiveAddNodes(inputValue[0], oldOutput, newOutput); + } + } + } + return newOutput; + } + initializeComfyUIHooks() { + const rgthree = this; + const queuePrompt = app.queuePrompt; + app.queuePrompt = async function () { + rgthree.processingQueue = true; + rgthree.dispatchCustomEvent("queue"); + try { + await queuePrompt.apply(app, [...arguments]); + } + finally { + rgthree.processingQueue = false; + rgthree.dispatchCustomEvent("queue-end"); + } + }; + const loadApiJson = app.loadApiJson; + app.loadApiJson = async function () { + rgthree.loadingApiJson = true; + try { + loadApiJson.apply(app, [...arguments]); + } + finally { + rgthree.loadingApiJson = false; + } + }; + const graphToPrompt = app.graphToPrompt; + app.graphToPrompt = async function () { + rgthree.dispatchCustomEvent("graph-to-prompt"); + let promise = graphToPrompt.apply(app, [...arguments]); + await promise; + rgthree.dispatchCustomEvent("graph-to-prompt-end"); + return promise; + }; + const apiQueuePrompt = api.queuePrompt; + api.queuePrompt = async function (index, prompt) { + var _a; + if (((_a = rgthree.queueNodeIds) === null || _a === void 0 ? void 0 : _a.length) && prompt.output) { + const oldOutput = prompt.output; + let newOutput = {}; + for (const queueNodeId of rgthree.queueNodeIds) { + rgthree.recursiveAddNodes(String(queueNodeId), oldOutput, newOutput); + } + prompt.output = newOutput; + } + rgthree.dispatchCustomEvent("comfy-api-queue-prompt-before", { + workflow: prompt.workflow, + output: prompt.output, + }); + const response = apiQueuePrompt.apply(app, [index, prompt]); + rgthree.dispatchCustomEvent("comfy-api-queue-prompt-end"); + return response; + }; + const clean = app.clean; + app.clean = function () { + rgthree.clearAllMessages(); + clean && clean.apply(app, [...arguments]); + }; + const loadGraphData = app.loadGraphData; + app.loadGraphData = function (graph) { + if (rgthree.monitorLinkTimeout) { + clearTimeout(rgthree.monitorLinkTimeout); + rgthree.monitorLinkTimeout = null; + } + rgthree.clearAllMessages(); + let graphCopy; + try { + graphCopy = JSON.parse(JSON.stringify(graph)); + } + catch (e) { + graphCopy = null; + } + setTimeout(() => { + var _a, _b, _c; + const wasLoadingAborted = (_b = (_a = document + .querySelector(".comfy-modal-content")) === null || _a === void 0 ? void 0 : _a.textContent) === null || _b === void 0 ? void 0 : _b.includes("Loading aborted due"); + const graphToUse = wasLoadingAborted ? graphCopy || graph : app.graph; + const fixBadLinksResult = fixBadLinks(graphToUse); + if (fixBadLinksResult.hasBadLinks) { + const [n, v] = rgthree.logParts(LogLevel.WARN, `The workflow you've loaded has corrupt linking data. Open ${new URL(location.href).origin}/rgthree/link_fixer to try to fix.`); + (_c = console[n]) === null || _c === void 0 ? 
void 0 : _c.call(console, ...v); + if (CONFIG_SERVICE.getConfigValue("features.show_alerts_for_corrupt_workflows")) { + rgthree.showMessage({ + id: "bad-links", + type: "warn", + message: "The workflow you've loaded has corrupt linking data that may be able to be fixed.", + actions: [ + { + label: "Open fixer", + href: "/rgthree/link_fixer", + }, + { + label: "Fix in place", + href: "/rgthree/link_fixer", + callback: (event) => { + event.stopPropagation(); + event.preventDefault(); + if (confirm("This will attempt to fix in place. Please make sure to have a saved copy of your workflow.")) { + try { + const fixBadLinksResult = fixBadLinks(graphToUse, true); + if (!fixBadLinksResult.hasBadLinks) { + rgthree.hideMessage("bad-links"); + alert("Success! It's possible some valid links may have been affected. Please check and verify your workflow."); + wasLoadingAborted && app.loadGraphData(fixBadLinksResult.graph); + if (CONFIG_SERVICE.getConfigValue("features.monitor_for_corrupt_links") || + CONFIG_SERVICE.getConfigValue("features.monitor_bad_links")) { + rgthree.monitorLinkTimeout = setTimeout(() => { + rgthree.monitorBadLinks(); + }, 5000); + } + } + } + catch (e) { + console.error(e); + alert("Unsuccessful at fixing corrupt data. :("); + rgthree.hideMessage("bad-links"); + } + } + }, + }, + ], + }); + } + } + else if (CONFIG_SERVICE.getConfigValue("features.monitor_for_corrupt_links") || + CONFIG_SERVICE.getConfigValue("features.monitor_bad_links")) { + rgthree.monitorLinkTimeout = setTimeout(() => { + rgthree.monitorBadLinks(); + }, 5000); + } + }, 100); + return loadGraphData && loadGraphData.apply(app, [...arguments]); + }; + } + getNodeFromInitialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff(node) { + var _a, _b, _c; + return ((_c = (_b = (_a = this.initialGraphToPromptSerializedWorkflowBecauseComfyUIBrokeStuff) === null || _a === void 0 ? void 0 : _a.nodes) === null || _b === void 0 ? void 0 : _b.find((n) => n.id === node.id)) !== null && _c !== void 0 ? 
_c : null); + } + async showMessage(data) { + let container = document.querySelector(".rgthree-top-messages-container"); + if (!container) { + container = document.createElement("div"); + container.classList.add("rgthree-top-messages-container"); + document.body.appendChild(container); + } + const dialogs = query("dialog[open]"); + if (dialogs.length) { + let dialog = dialogs[dialogs.length - 1]; + dialog.appendChild(container); + dialog.addEventListener("close", (e) => { + document.body.appendChild(container); + }); + } + await this.hideMessage(data.id); + const messageContainer = document.createElement("div"); + messageContainer.setAttribute("type", data.type || "info"); + const message = document.createElement("span"); + message.innerHTML = data.message; + messageContainer.appendChild(message); + for (let a = 0; a < (data.actions || []).length; a++) { + const action = data.actions[a]; + if (a > 0) { + const sep = document.createElement("span"); + sep.innerHTML = " | "; + messageContainer.appendChild(sep); + } + const actionEl = document.createElement("a"); + actionEl.innerText = action.label; + if (action.href) { + actionEl.target = "_blank"; + actionEl.href = action.href; + } + if (action.callback) { + actionEl.onclick = (e) => { + return action.callback(e); + }; + } + messageContainer.appendChild(actionEl); + } + const messageAnimContainer = document.createElement("div"); + messageAnimContainer.setAttribute("msg-id", data.id); + messageAnimContainer.appendChild(messageContainer); + container.appendChild(messageAnimContainer); + await wait(64); + messageAnimContainer.style.marginTop = `-${messageAnimContainer.offsetHeight}px`; + await wait(64); + messageAnimContainer.classList.add("-show"); + if (data.timeout) { + await wait(data.timeout); + this.hideMessage(data.id); + } + } + async hideMessage(id) { + const msg = document.querySelector(`.rgthree-top-messages-container > [msg-id="${id}"]`); + if (msg === null || msg === void 0 ? void 0 : msg.classList.contains("-show")) { + msg.classList.remove("-show"); + await wait(750); + } + msg && msg.remove(); + } + async clearAllMessages() { + let container = document.querySelector(".rgthree-top-messages-container"); + container && (container.innerHTML = ""); + } + setLogLevel(level) { + if (typeof level === "string") { + level = LogLevelKeyToLogLevel[CONFIG_SERVICE.getConfigValue("log_level")]; + } + if (level != null) { + GLOBAL_LOG_LEVEL = level; + } + } + logParts(level, message, ...args) { + return this.logger.logParts(level, message, ...args); + } + newLogSession(name) { + return this.logger.newSession(name); + } + isDevMode() { + if (window.location.href.includes("rgthree-dev=false")) { + return false; + } + return GLOBAL_LOG_LEVEL >= LogLevel.DEBUG || window.location.href.includes("rgthree-dev"); + } + isDebugMode() { + if (!this.isDevMode() || window.location.href.includes("rgthree-debug=false")) { + return false; + } + return window.location.href.includes("rgthree-debug"); + } + monitorBadLinks() { + const badLinksFound = fixBadLinks(app.graph); + if (badLinksFound.hasBadLinks && !this.monitorBadLinksAlerted) { + this.monitorBadLinksAlerted = true; + alert(`Problematic links just found in live data. Can you save your workflow and file a bug with ` + + `the last few steps you took to trigger this at ` + + `https://github.com/rgthree/rgthree-comfy/issues. 
Thank you!`); + } + else if (!badLinksFound.hasBadLinks) { + this.monitorBadLinksAlerted = false; + } + this.monitorLinkTimeout = setTimeout(() => { + this.monitorBadLinks(); + }, 5000); + } +} +function getBookmarks() { + const graph = app.graph; + const bookmarks = graph._nodes + .filter((n) => n.type === NodeTypesString.BOOKMARK) + .sort((a, b) => a.title.localeCompare(b.title)) + .map((n) => ({ + content: `[${n.shortcutKey}] ${n.title}`, + className: "rgthree-contextmenu-item", + callback: () => { + n.canvasToBookmark(); + }, + })); + return !bookmarks.length + ? [] + : [ + { + content: "🔖 Bookmarks", + disabled: true, + className: "rgthree-contextmenu-item rgthree-contextmenu-label", + }, + ...bookmarks, + ]; +} +export const rgthree = new Rgthree(); +window.rgthree = rgthree; diff --git a/rgthree-comfy/web/comfyui/seed.js b/rgthree-comfy/web/comfyui/seed.js new file mode 100644 index 0000000000000000000000000000000000000000..65bd1c3fbe2433b754cfdf1a43a2f40afb2ccb18 --- /dev/null +++ b/rgthree-comfy/web/comfyui/seed.js @@ -0,0 +1,188 @@ +import { app } from "../../scripts/app.js"; +import { ComfyWidgets } from "../../scripts/widgets.js"; +import { RgthreeBaseServerNode } from "./base_node.js"; +import { rgthree } from "./rgthree.js"; +import { addConnectionLayoutSupport } from "./utils.js"; +import { NodeTypesString } from "./constants.js"; +const LAST_SEED_BUTTON_LABEL = "♻️ (Use Last Queued Seed)"; +const SPECIAL_SEED_RANDOM = -1; +const SPECIAL_SEED_INCREMENT = -2; +const SPECIAL_SEED_DECREMENT = -3; +const SPECIAL_SEEDS = [SPECIAL_SEED_RANDOM, SPECIAL_SEED_INCREMENT, SPECIAL_SEED_DECREMENT]; +class RgthreeSeed extends RgthreeBaseServerNode { + constructor(title = RgthreeSeed.title) { + super(title); + this.serialize_widgets = true; + this.logger = rgthree.newLogSession(`[Seed]`); + this.lastSeed = undefined; + this.serializedCtx = {}; + this.lastSeedValue = null; + this.randMax = 1125899906842624; + this.randMin = 0; + this.randomRange = 1125899906842624; + this.handleApiHijackingBound = this.handleApiHijacking.bind(this); + rgthree.addEventListener("comfy-api-queue-prompt-before", this.handleApiHijackingBound); + } + onRemoved() { + rgthree.removeEventListener("comfy-api-queue-prompt-before", this.handleApiHijackingBound); + } + configure(info) { + var _a; + super.configure(info); + if ((_a = this.properties) === null || _a === void 0 ? void 0 : _a["showLastSeed"]) { + this.addLastSeedValue(); + } + } + async handleAction(action) { + if (action === "Randomize Each Time") { + this.seedWidget.value = SPECIAL_SEED_RANDOM; + } + else if (action === "Use Last Queued Seed") { + this.seedWidget.value = this.lastSeed != null ? this.lastSeed : this.seedWidget.value; + this.lastSeedButton.name = LAST_SEED_BUTTON_LABEL; + this.lastSeedButton.disabled = true; + } + } + onNodeCreated() { + var _a; + (_a = super.onNodeCreated) === null || _a === void 0 ? 
void 0 : _a.call(this); + for (const [i, w] of this.widgets.entries()) { + if (w.name === "seed") { + this.seedWidget = w; + this.seedWidget.value = SPECIAL_SEED_RANDOM; + } + else if (w.name === "control_after_generate") { + this.widgets.splice(i, 1); + } + } + let step = this.seedWidget.options.step || 1; + this.randMax = Math.min(1125899906842624, this.seedWidget.options.max); + this.randMin = Math.max(0, this.seedWidget.options.min); + this.randomRange = (this.randMax - Math.max(0, this.randMin)) / (step / 10); + this.addWidget("button", "🎲 Randomize Each Time", null, () => { + this.seedWidget.value = SPECIAL_SEED_RANDOM; + }, { serialize: false }); + this.addWidget("button", "🎲 New Fixed Random", null, () => { + this.seedWidget.value = + Math.floor(Math.random() * this.randomRange) * (step / 10) + this.randMin; + }, { serialize: false }); + this.lastSeedButton = this.addWidget("button", LAST_SEED_BUTTON_LABEL, null, () => { + this.seedWidget.value = this.lastSeed != null ? this.lastSeed : this.seedWidget.value; + this.lastSeedButton.name = LAST_SEED_BUTTON_LABEL; + this.lastSeedButton.disabled = true; + }, { width: 50, serialize: false }); + this.lastSeedButton.disabled = true; + } + getExtraMenuOptions(canvas, options) { + var _a; + (_a = super.getExtraMenuOptions) === null || _a === void 0 ? void 0 : _a.apply(this, [...arguments]); + options.splice(options.length - 1, 0, { + content: "Show/Hide Last Seed Value", + callback: (_value, _options, _event, _parentMenu, _node) => { + this.properties["showLastSeed"] = !this.properties["showLastSeed"]; + if (this.properties["showLastSeed"]) { + this.addLastSeedValue(); + } + else { + this.removeLastSeedValue(); + } + }, + }); + } + addLastSeedValue() { + if (this.lastSeedValue) + return; + this.lastSeedValue = ComfyWidgets["STRING"](this, "last_seed", ["STRING", { multiline: true }], app).widget; + this.lastSeedValue.inputEl.readOnly = true; + this.lastSeedValue.inputEl.style.fontSize = "0.75rem"; + this.lastSeedValue.inputEl.style.textAlign = "center"; + this.computeSize(); + } + removeLastSeedValue() { + if (!this.lastSeedValue) + return; + this.lastSeedValue.inputEl.remove(); + this.widgets.splice(this.widgets.indexOf(this.lastSeedValue), 1); + this.lastSeedValue = null; + this.computeSize(); + } + handleApiHijacking(e) { + var _a, _b, _c, _d; + if (this.mode === LiteGraph.NEVER || this.mode === 4) { + return; + } + const workflow = e.detail.workflow; + const output = e.detail.output; + let workflowNode = (_b = (_a = workflow === null || workflow === void 0 ? void 0 : workflow.nodes) === null || _a === void 0 ? void 0 : _a.find((n) => n.id === this.id)) !== null && _b !== void 0 ? _b : null; + let outputInputs = (_c = output === null || output === void 0 ? void 0 : output[this.id]) === null || _c === void 0 ? void 0 : _c.inputs; + if (!workflowNode || + !outputInputs || + outputInputs[this.seedWidget.name || "seed"] === undefined) { + const [n, v] = this.logger.warnParts(`Node ${this.id} not found in prompt data sent to server. This may be fine if only ` + + `queuing part of the workflow. If not, then this could be a bug.`); + (_d = console[n]) === null || _d === void 0 ? 
void 0 : _d.call(console, ...v); + return; + } + const seedToUse = this.getSeedToUse(); + const seedWidgetndex = this.widgets.indexOf(this.seedWidget); + workflowNode.widgets_values[seedWidgetndex] = seedToUse; + outputInputs[this.seedWidget.name || "seed"] = seedToUse; + this.lastSeed = seedToUse; + if (seedToUse != this.seedWidget.value) { + this.lastSeedButton.name = `♻️ ${this.lastSeed}`; + this.lastSeedButton.disabled = false; + } + else { + this.lastSeedButton.name = LAST_SEED_BUTTON_LABEL; + this.lastSeedButton.disabled = true; + } + if (this.lastSeedValue) { + this.lastSeedValue.value = `Last Seed: ${this.lastSeed}`; + } + } + getSeedToUse() { + const inputSeed = this.seedWidget.value; + let seedToUse = null; + if (SPECIAL_SEEDS.includes(inputSeed)) { + if (typeof this.lastSeed === "number" && !SPECIAL_SEEDS.includes(this.lastSeed)) { + if (inputSeed === SPECIAL_SEED_INCREMENT) { + seedToUse = this.lastSeed + 1; + } + else if (inputSeed === SPECIAL_SEED_DECREMENT) { + seedToUse = this.lastSeed - 1; + } + } + if (seedToUse == null || SPECIAL_SEEDS.includes(seedToUse)) { + seedToUse = + Math.floor(Math.random() * this.randomRange) * + ((this.seedWidget.options.step || 1) / 10) + + this.randMin; + } + } + return seedToUse !== null && seedToUse !== void 0 ? seedToUse : inputSeed; + } + static setUp(comfyClass, nodeData) { + RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, RgthreeSeed); + } + static onRegisteredForOverride(comfyClass, ctxClass) { + addConnectionLayoutSupport(RgthreeSeed, app, [ + ["Left", "Right"], + ["Right", "Left"], + ]); + setTimeout(() => { + RgthreeSeed.category = comfyClass.category; + }); + } +} +RgthreeSeed.title = NodeTypesString.SEED; +RgthreeSeed.type = NodeTypesString.SEED; +RgthreeSeed.comfyClass = NodeTypesString.SEED; +RgthreeSeed.exposedActions = ["Randomize Each Time", "Use Last Queued Seed"]; +app.registerExtension({ + name: "rgthree.Seed", + async beforeRegisterNodeDef(nodeType, nodeData) { + if (nodeData.name === RgthreeSeed.type) { + RgthreeSeed.setUp(nodeType, nodeData); + } + }, +}); diff --git a/rgthree-comfy/web/comfyui/services/bookmarks_services.js b/rgthree-comfy/web/comfyui/services/bookmarks_services.js new file mode 100644 index 0000000000000000000000000000000000000000..59ac6e2c4dd2d6e58f36502e74e3fb3413d68266 --- /dev/null +++ b/rgthree-comfy/web/comfyui/services/bookmarks_services.js @@ -0,0 +1,10 @@ +import { app } from "../../../scripts/app.js"; +import { NodeTypesString } from "../constants.js"; +class BookmarksService { + getCurrentBookmarks() { + return app.graph._nodes + .filter((n) => n.type === NodeTypesString.BOOKMARK) + .sort((a, b) => a.title.localeCompare(b.title)); + } +} +export const SERVICE = new BookmarksService(); diff --git a/rgthree-comfy/web/comfyui/services/config_service.js b/rgthree-comfy/web/comfyui/services/config_service.js new file mode 100644 index 0000000000000000000000000000000000000000..34f1087173c01cda464a519f4d8ebdfb81fbf25a --- /dev/null +++ b/rgthree-comfy/web/comfyui/services/config_service.js @@ -0,0 +1,28 @@ +import { rgthreeConfig } from "../../../rgthree/config.js"; +import { getObjectValue, setObjectValue } from "../../../rgthree/common/shared_utils.js"; +import { rgthreeApi } from "../../../rgthree/common/rgthree_api.js"; +class ConfigService extends EventTarget { + getConfigValue(key, def) { + return getObjectValue(rgthreeConfig, key, def); + } + getFeatureValue(key, def) { + key = "features." 
+ key.replace(/^features\./, ""); + return getObjectValue(rgthreeConfig, key, def); + } + async setConfigValues(changed) { + const body = new FormData(); + body.append("json", JSON.stringify(changed)); + const response = await rgthreeApi.fetchJson("/config", { method: "POST", body }); + if (response.status === "ok") { + for (const [key, value] of Object.entries(changed)) { + setObjectValue(rgthreeConfig, key, value); + this.dispatchEvent(new CustomEvent("config-change", { detail: { key, value } })); + } + } + else { + return false; + } + return true; + } +} +export const SERVICE = new ConfigService(); diff --git a/rgthree-comfy/web/comfyui/services/context_service.js b/rgthree-comfy/web/comfyui/services/context_service.js new file mode 100644 index 0000000000000000000000000000000000000000..27e19658cbdc20d6213cd673fa6f10f79ee9ee7a --- /dev/null +++ b/rgthree-comfy/web/comfyui/services/context_service.js @@ -0,0 +1,51 @@ +import { getConnectedOutputNodesAndFilterPassThroughs } from "../utils.js"; +export let SERVICE; +const OWNED_PREFIX = "+"; +const REGEX_PREFIX = /^[\+⚠️]\s*/; +const REGEX_EMPTY_INPUT = /^\+\s*$/; +export function stripContextInputPrefixes(name) { + return name.replace(REGEX_PREFIX, ""); +} +export function getContextOutputName(inputName) { + if (inputName === "base_ctx") + return "CONTEXT"; + return stripContextInputPrefixes(inputName).toUpperCase(); +} +export var InputMutationOperation; +(function (InputMutationOperation) { + InputMutationOperation[InputMutationOperation["UNKNOWN"] = 0] = "UNKNOWN"; + InputMutationOperation[InputMutationOperation["ADDED"] = 1] = "ADDED"; + InputMutationOperation[InputMutationOperation["REMOVED"] = 2] = "REMOVED"; + InputMutationOperation[InputMutationOperation["RENAMED"] = 3] = "RENAMED"; +})(InputMutationOperation || (InputMutationOperation = {})); +export class ContextService { + constructor() { + if (SERVICE) { + throw new Error("ContextService was already instantiated."); + } + } + onInputChanges(node, mutation) { + const childCtxs = getConnectedOutputNodesAndFilterPassThroughs(node, node, 0); + for (const childCtx of childCtxs) { + childCtx.handleUpstreamMutation(mutation); + } + } + getDynamicContextInputsData(node) { + return node + .getContextInputsList() + .map((input, index) => ({ + name: stripContextInputPrefixes(input.name), + type: String(input.type), + index, + })) + .filter((i) => i.type !== "*"); + } + getDynamicContextOutputsData(node) { + return node.outputs.map((output, index) => ({ + name: stripContextInputPrefixes(output.name), + type: String(output.type), + index, + })); + } +} +SERVICE = new ContextService(); diff --git a/rgthree-comfy/web/comfyui/services/fast_groups_service.js b/rgthree-comfy/web/comfyui/services/fast_groups_service.js new file mode 100644 index 0000000000000000000000000000000000000000..acc8eb15610d547d7b3575adce205a0f6612c654 --- /dev/null +++ b/rgthree-comfy/web/comfyui/services/fast_groups_service.js @@ -0,0 +1,138 @@ +import { app } from "../../../scripts/app.js"; +class FastGroupsService { + constructor() { + this.msThreshold = 400; + this.msLastUnsorted = 0; + this.msLastAlpha = 0; + this.msLastPosition = 0; + this.groupsUnsorted = []; + this.groupsSortedAlpha = []; + this.groupsSortedPosition = []; + this.fastGroupNodes = []; + this.runScheduledForMs = null; + this.runScheduleTimeout = null; + this.runScheduleAnimation = null; + this.cachedNodeBoundings = null; + } + addFastGroupNode(node) { + this.fastGroupNodes.push(node); + this.scheduleRun(8); + } + removeFastGroupNode(node) { + var 
_a; + const index = this.fastGroupNodes.indexOf(node); + if (index > -1) { + this.fastGroupNodes.splice(index, 1); + } + if (!((_a = this.fastGroupNodes) === null || _a === void 0 ? void 0 : _a.length)) { + this.clearScheduledRun(); + this.groupsUnsorted = []; + this.groupsSortedAlpha = []; + this.groupsSortedPosition = []; + } + } + run() { + if (!this.runScheduledForMs) { + return; + } + for (const node of this.fastGroupNodes) { + node.refreshWidgets(); + } + this.clearScheduledRun(); + this.scheduleRun(); + } + scheduleRun(ms = 500) { + if (this.runScheduledForMs && ms < this.runScheduledForMs) { + this.clearScheduledRun(); + } + if (!this.runScheduledForMs && this.fastGroupNodes.length) { + this.runScheduledForMs = ms; + this.runScheduleTimeout = setTimeout(() => { + this.runScheduleAnimation = requestAnimationFrame(() => this.run()); + }, ms); + } + } + clearScheduledRun() { + this.runScheduleTimeout && clearTimeout(this.runScheduleTimeout); + this.runScheduleAnimation && cancelAnimationFrame(this.runScheduleAnimation); + this.runScheduleTimeout = null; + this.runScheduleAnimation = null; + this.runScheduledForMs = null; + } + getBoundingsForAllNodes() { + if (!this.cachedNodeBoundings) { + this.cachedNodeBoundings = {}; + for (const node of app.graph._nodes) { + this.cachedNodeBoundings[node.id] = node.getBounding(); + } + setTimeout(() => { + this.cachedNodeBoundings = null; + }, 50); + } + return this.cachedNodeBoundings; + } + recomputeInsideNodesForGroup(group) { + const cachedBoundings = this.getBoundingsForAllNodes(); + const nodes = group.graph._nodes; + group._nodes.length = 0; + for (const node of nodes) { + const node_bounding = cachedBoundings[node.id]; + if (!node_bounding || !LiteGraph.overlapBounding(group._bounding, node_bounding)) { + continue; + } + group._nodes.push(node); + } + } + getGroupsUnsorted(now) { + const canvas = app.canvas; + const graph = app.graph; + if (!canvas.selected_group_moving && + (!this.groupsUnsorted.length || now - this.msLastUnsorted > this.msThreshold)) { + this.groupsUnsorted = [...graph._groups]; + for (const group of this.groupsUnsorted) { + this.recomputeInsideNodesForGroup(group); + group._rgthreeHasAnyActiveNode = group._nodes.some((n) => n.mode === LiteGraph.ALWAYS); + } + this.msLastUnsorted = now; + } + return this.groupsUnsorted; + } + getGroupsAlpha(now) { + const graph = app.graph; + if (!this.groupsSortedAlpha.length || now - this.msLastAlpha > this.msThreshold) { + this.groupsSortedAlpha = [...this.getGroupsUnsorted(now)].sort((a, b) => { + return a.title.localeCompare(b.title); + }); + this.msLastAlpha = now; + } + return this.groupsSortedAlpha; + } + getGroupsPosition(now) { + const graph = app.graph; + if (!this.groupsSortedPosition.length || now - this.msLastPosition > this.msThreshold) { + this.groupsSortedPosition = [...this.getGroupsUnsorted(now)].sort((a, b) => { + const aY = Math.floor(a._pos[1] / 30); + const bY = Math.floor(b._pos[1] / 30); + if (aY == bY) { + const aX = Math.floor(a._pos[0] / 30); + const bX = Math.floor(b._pos[0] / 30); + return aX - bX; + } + return aY - bY; + }); + this.msLastPosition = now; + } + return this.groupsSortedPosition; + } + getGroups(sort) { + const now = +new Date(); + if (sort === "alphanumeric") { + return this.getGroupsAlpha(now); + } + if (sort === "position") { + return this.getGroupsPosition(now); + } + return this.getGroupsUnsorted(now); + } +} +export const SERVICE = new FastGroupsService(); diff --git a/rgthree-comfy/web/comfyui/services/key_events_services.js 
b/rgthree-comfy/web/comfyui/services/key_events_services.js new file mode 100644 index 0000000000000000000000000000000000000000..56619db62e81ec8f4824fb3e9643bfe7507926ad --- /dev/null +++ b/rgthree-comfy/web/comfyui/services/key_events_services.js @@ -0,0 +1,105 @@ +class KeyEventService extends EventTarget { + constructor() { + var _a, _b, _c; + super(); + this.downKeys = {}; + this.ctrlKey = false; + this.altKey = false; + this.metaKey = false; + this.shiftKey = false; + this.isMac = !!(((_a = navigator.platform) === null || _a === void 0 ? void 0 : _a.toLocaleUpperCase().startsWith("MAC")) || + ((_c = (_b = navigator.userAgentData) === null || _b === void 0 ? void 0 : _b.platform) === null || _c === void 0 ? void 0 : _c.toLocaleUpperCase().startsWith("MAC"))); + this.initialize(); + } + initialize() { + const that = this; + const processKey = LGraphCanvas.prototype.processKey; + LGraphCanvas.prototype.processKey = function (e) { + if (e.type === "keydown" || e.type === "keyup") { + that.handleKeyDownOrUp(e); + } + return processKey.apply(this, [...arguments]); + }; + window.addEventListener("keydown", (e) => { + that.handleKeyDownOrUp(e); + }); + window.addEventListener("keyup", (e) => { + that.handleKeyDownOrUp(e); + }); + document.addEventListener("visibilitychange", (e) => { + this.clearKeydowns(); + }); + window.addEventListener("blur", (e) => { + this.clearKeydowns(); + }); + } + handleKeyDownOrUp(e) { + const key = e.key.toLocaleUpperCase(); + if ((e.type === 'keydown' && this.downKeys[key] === true) + || (e.type === 'keyup' && this.downKeys[key] === undefined)) { + return; + } + this.ctrlKey = !!e.ctrlKey; + this.altKey = !!e.altKey; + this.metaKey = !!e.metaKey; + this.shiftKey = !!e.shiftKey; + if (e.type === "keydown") { + this.downKeys[key] = true; + this.dispatchCustomEvent("keydown", { originalEvent: e }); + } + else if (e.type === "keyup") { + if (key === "META" && this.isMac) { + this.clearKeydowns(); + } + else { + delete this.downKeys[key]; + } + this.dispatchCustomEvent("keyup", { originalEvent: e }); + } + } + clearKeydowns() { + this.ctrlKey = false; + this.altKey = false; + this.metaKey = false; + this.shiftKey = false; + for (const key in this.downKeys) + delete this.downKeys[key]; + } + dispatchCustomEvent(event, detail) { + if (detail != null) { + return this.dispatchEvent(new CustomEvent(event, { detail })); + } + return this.dispatchEvent(new CustomEvent(event)); + } + getKeysFromShortcut(shortcut) { + let keys; + if (typeof shortcut === "string") { + shortcut = shortcut.replace(/\s/g, ""); + shortcut = shortcut.replace(/^\+/, "__PLUS__").replace(/\+\+/, "+__PLUS__"); + keys = shortcut.split("+").map((i) => i.replace("__PLUS__", "+")); + } + else { + keys = [...shortcut]; + } + return keys.map((k) => k.toLocaleUpperCase()); + } + areAllKeysDown(keys) { + keys = this.getKeysFromShortcut(keys); + return keys.every((k) => { + return this.downKeys[k]; + }); + } + areOnlyKeysDown(keys, alsoAllowShift = false) { + keys = this.getKeysFromShortcut(keys); + const allKeysDown = this.areAllKeysDown(keys); + const downKeysLength = Object.values(this.downKeys).length; + if (allKeysDown && keys.length === downKeysLength) { + return true; + } + if (alsoAllowShift && !keys.includes("SHIFT") && keys.length === downKeysLength - 1) { + return allKeysDown && this.areAllKeysDown(["SHIFT"]); + } + return false; + } +} +export const SERVICE = new KeyEventService(); diff --git a/rgthree-comfy/web/comfyui/utils.js b/rgthree-comfy/web/comfyui/utils.js new file mode 100644 index 
0000000000000000000000000000000000000000..dad76b2ff56b9de304a17d8a5a0e51d1662d0e59 --- /dev/null +++ b/rgthree-comfy/web/comfyui/utils.js @@ -0,0 +1,630 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; +import { getResolver, wait } from "../../rgthree/common/shared_utils.js"; +import { RgthreeHelpDialog } from "../../rgthree/common/dialog.js"; +const oldApiGetNodeDefs = api.getNodeDefs; +api.getNodeDefs = async function () { + const defs = await oldApiGetNodeDefs.call(api); + this.dispatchEvent(new CustomEvent("fresh-node-defs", { detail: defs })); + return defs; +}; +export var IoDirection; +(function (IoDirection) { + IoDirection[IoDirection["INPUT"] = 0] = "INPUT"; + IoDirection[IoDirection["OUTPUT"] = 1] = "OUTPUT"; +})(IoDirection || (IoDirection = {})); +const PADDING = 0; +export const LAYOUT_LABEL_TO_DATA = { + Left: [LiteGraph.LEFT, [0, 0.5], [PADDING, 0]], + Right: [LiteGraph.RIGHT, [1, 0.5], [-PADDING, 0]], + Top: [LiteGraph.UP, [0.5, 0], [0, PADDING]], + Bottom: [LiteGraph.DOWN, [0.5, 1], [0, -PADDING]], +}; +export const LAYOUT_LABEL_OPPOSITES = { + Left: "Right", + Right: "Left", + Top: "Bottom", + Bottom: "Top", +}; +export const LAYOUT_CLOCKWISE = ["Top", "Right", "Bottom", "Left"]; +export function addMenuItem(node, _app, config, after = "Shape") { + const oldGetExtraMenuOptions = node.prototype.getExtraMenuOptions; + node.prototype.getExtraMenuOptions = function (canvas, menuOptions) { + oldGetExtraMenuOptions && oldGetExtraMenuOptions.apply(this, [canvas, menuOptions]); + addMenuItemOnExtraMenuOptions(this, config, menuOptions, after); + }; +} +let canvasResolver = null; +export function waitForCanvas() { + if (canvasResolver === null) { + canvasResolver = getResolver(); + function _waitForCanvas() { + if (!canvasResolver.completed) { + if (app === null || app === void 0 ? void 0 : app.canvas) { + canvasResolver.resolve(app.canvas); + } + else { + requestAnimationFrame(_waitForCanvas); + } + } + } + _waitForCanvas(); + } + return canvasResolver.promise; +} +let graphResolver = null; +export function waitForGraph() { + if (graphResolver === null) { + graphResolver = getResolver(); + function _wait() { + if (!graphResolver.completed) { + if (app === null || app === void 0 ? void 0 : app.graph) { + graphResolver.resolve(app.graph); + } + else { + requestAnimationFrame(_wait); + } + } + } + _wait(); + } + return graphResolver.promise; +} +export function addMenuItemOnExtraMenuOptions(node, config, menuOptions, after = "Shape") { + let idx = menuOptions + .slice() + .reverse() + .findIndex((option) => option === null || option === void 0 ? void 0 : option.isRgthree); + if (idx == -1) { + idx = menuOptions.findIndex((option) => { var _a; return (_a = option === null || option === void 0 ? void 0 : option.content) === null || _a === void 0 ? void 0 : _a.includes(after); }) + 1; + if (!idx) { + idx = menuOptions.length - 1; + } + menuOptions.splice(idx, 0, null); + idx++; + } + else { + idx = menuOptions.length - idx; + } + const subMenuOptions = typeof config.subMenuOptions === "function" + ? config.subMenuOptions(node) + : config.subMenuOptions; + menuOptions.splice(idx, 0, { + content: typeof config.name == "function" ? config.name(node) : config.name, + has_submenu: !!(subMenuOptions === null || subMenuOptions === void 0 ? void 0 : subMenuOptions.length), + isRgthree: true, + callback: (value, _options, event, parentMenu, _node) => { + if (!!(subMenuOptions === null || subMenuOptions === void 0 ? 
void 0 : subMenuOptions.length)) { + new LiteGraph.ContextMenu(subMenuOptions.map((option) => (option ? { content: option } : null)), { + event, + parentMenu, + callback: (subValue, _options, _event, _parentMenu, _node) => { + if (config.property) { + node.properties = node.properties || {}; + node.properties[config.property] = config.prepareValue + ? config.prepareValue(subValue.content || '', node) + : subValue.content || ''; + } + config.callback && config.callback(node, subValue === null || subValue === void 0 ? void 0 : subValue.content); + }, + }); + return; + } + if (config.property) { + node.properties = node.properties || {}; + node.properties[config.property] = config.prepareValue + ? config.prepareValue(node.properties[config.property], node) + : !node.properties[config.property]; + } + config.callback && config.callback(node, value === null || value === void 0 ? void 0 : value.content); + }, + }); +} +export function addConnectionLayoutSupport(node, app, options = [ + ["Left", "Right"], + ["Right", "Left"], +], callback) { + addMenuItem(node, app, { + name: "Connections Layout", + property: "connections_layout", + subMenuOptions: options.map((option) => option[0] + (option[1] ? " -> " + option[1] : "")), + prepareValue: (value, node) => { + var _a; + const values = value.split(" -> "); + if (!values[1] && !((_a = node.outputs) === null || _a === void 0 ? void 0 : _a.length)) { + values[1] = LAYOUT_LABEL_OPPOSITES[values[0]]; + } + if (!LAYOUT_LABEL_TO_DATA[values[0]] || !LAYOUT_LABEL_TO_DATA[values[1]]) { + throw new Error(`New Layout invalid: [${values[0]}, ${values[1]}]`); + } + return values; + }, + callback: (node) => { + callback && callback(node); + app.graph.setDirtyCanvas(true, true); + }, + }); + node.prototype.getConnectionPos = function (isInput, slotNumber, out) { + return getConnectionPosForLayout(this, isInput, slotNumber, out); + }; +} +export function setConnectionsLayout(node, newLayout) { + var _a; + newLayout = newLayout || node.defaultConnectionsLayout || ["Left", "Right"]; + if (!newLayout[1] && !((_a = node.outputs) === null || _a === void 0 ? void 0 : _a.length)) { + newLayout[1] = LAYOUT_LABEL_OPPOSITES[newLayout[0]]; + } + if (!LAYOUT_LABEL_TO_DATA[newLayout[0]] || !LAYOUT_LABEL_TO_DATA[newLayout[1]]) { + throw new Error(`New Layout invalid: [${newLayout[0]}, ${newLayout[1]}]`); + } + node.properties = node.properties || {}; + node.properties["connections_layout"] = newLayout; +} +export function setConnectionsCollapse(node, collapseConnections = null) { + node.properties = node.properties || {}; + collapseConnections = + collapseConnections !== null ? collapseConnections : !node.properties["collapse_connections"]; + node.properties["collapse_connections"] = collapseConnections; +} +export function getConnectionPosForLayout(node, isInput, slotNumber, out) { + var _a, _b, _c; + out = out || new Float32Array(2); + node.properties = node.properties || {}; + const layout = node.properties["connections_layout"] || + node.defaultConnectionsLayout || ["Left", "Right"]; + const collapseConnections = node.properties["collapse_connections"] || false; + const offset = (_a = node.constructor.layout_slot_offset) !== null && _a !== void 0 ? _a : LiteGraph.NODE_SLOT_HEIGHT * 0.5; + let side = isInput ? layout[0] : layout[1]; + const otherSide = isInput ? layout[1] : layout[0]; + let data = LAYOUT_LABEL_TO_DATA[side]; + const slotList = node[isInput ? "inputs" : "outputs"]; + const cxn = slotList[slotNumber]; + if (!cxn) { + console.log("No connection found.. 
weird", isInput, slotNumber); + return out; + } + if (cxn.disabled) { + if (cxn.color_on !== "#666665") { + cxn._color_on_org = cxn._color_on_org || cxn.color_on; + cxn._color_off_org = cxn._color_off_org || cxn.color_off; + } + cxn.color_on = "#666665"; + cxn.color_off = "#666665"; + } + else if (cxn.color_on === "#666665") { + cxn.color_on = cxn._color_on_org || undefined; + cxn.color_off = cxn._color_off_org || undefined; + } + const displaySlot = collapseConnections + ? 0 + : slotNumber - + slotList.reduce((count, ioput, index) => { + count += index < slotNumber && ioput.hidden ? 1 : 0; + return count; + }, 0); + cxn.dir = data[0]; + if ((node.size[0] == 10 || node.size[1] == 10) && node.properties["connections_dir"]) { + cxn.dir = node.properties["connections_dir"][isInput ? 0 : 1]; + } + if (side === "Left") { + if (node.flags.collapsed) { + var w = node._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH; + out[0] = node.pos[0]; + out[1] = node.pos[1] - LiteGraph.NODE_TITLE_HEIGHT * 0.5; + } + else { + toggleConnectionLabel(cxn, !isInput || collapseConnections || !!node.hideSlotLabels); + out[0] = node.pos[0] + offset; + if ((_b = node.constructor) === null || _b === void 0 ? void 0 : _b.type.includes("Reroute")) { + out[1] = node.pos[1] + node.size[1] * 0.5; + } + else { + out[1] = + node.pos[1] + + (displaySlot + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + + (node.constructor.slot_start_y || 0); + } + } + } + else if (side === "Right") { + if (node.flags.collapsed) { + var w = node._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH; + out[0] = node.pos[0] + w; + out[1] = node.pos[1] - LiteGraph.NODE_TITLE_HEIGHT * 0.5; + } + else { + toggleConnectionLabel(cxn, isInput || collapseConnections || !!node.hideSlotLabels); + out[0] = node.pos[0] + node.size[0] + 1 - offset; + if ((_c = node.constructor) === null || _c === void 0 ? 
void 0 : _c.type.includes("Reroute")) { + out[1] = node.pos[1] + node.size[1] * 0.5; + } + else { + out[1] = + node.pos[1] + + (displaySlot + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + + (node.constructor.slot_start_y || 0); + } + } + } + else if (side === "Top") { + if (!cxn.has_old_label) { + cxn.has_old_label = true; + cxn.old_label = cxn.label; + cxn.label = " "; + } + out[0] = node.pos[0] + node.size[0] * 0.5; + out[1] = node.pos[1] + offset; + } + else if (side === "Bottom") { + if (!cxn.has_old_label) { + cxn.has_old_label = true; + cxn.old_label = cxn.label; + cxn.label = " "; + } + out[0] = node.pos[0] + node.size[0] * 0.5; + out[1] = node.pos[1] + node.size[1] - offset; + } + return out; +} +function toggleConnectionLabel(cxn, hide = true) { + if (hide) { + if (!cxn.has_old_label) { + cxn.has_old_label = true; + cxn.old_label = cxn.label; + } + cxn.label = " "; + } + else if (!hide && cxn.has_old_label) { + cxn.has_old_label = false; + cxn.label = cxn.old_label; + cxn.old_label = undefined; + } + return cxn; +} +export function addHelpMenuItem(node, content, menuOptions) { + addMenuItemOnExtraMenuOptions(node, { + name: "🛟 Node Help", + callback: (node) => { + if (node.showHelp) { + node.showHelp(); + } + else { + new RgthreeHelpDialog(node, content).show(); + } + }, + }, menuOptions, "Properties Panel"); +} +export var PassThroughFollowing; +(function (PassThroughFollowing) { + PassThroughFollowing[PassThroughFollowing["ALL"] = 0] = "ALL"; + PassThroughFollowing[PassThroughFollowing["NONE"] = 1] = "NONE"; + PassThroughFollowing[PassThroughFollowing["REROUTE_ONLY"] = 2] = "REROUTE_ONLY"; +})(PassThroughFollowing || (PassThroughFollowing = {})); +export function shouldPassThrough(node, passThroughFollowing = PassThroughFollowing.ALL) { + var _a; + const type = (_a = node === null || node === void 0 ? void 0 : node.constructor) === null || _a === void 0 ? 
void 0 : _a.type; + if (!type || passThroughFollowing === PassThroughFollowing.NONE) { + return false; + } + if (passThroughFollowing === PassThroughFollowing.REROUTE_ONLY) { + return type.includes("Reroute"); + } + return (type.includes("Reroute") || type.includes("Node Combiner") || type.includes("Node Collector")); +} +function filterOutPassthroughNodes(infos, passThroughFollowing = PassThroughFollowing.ALL) { + return infos.filter((i) => !shouldPassThrough(i.node, passThroughFollowing)); +} +export function getConnectedInputNodes(startNode, currentNode, slot, passThroughFollowing = PassThroughFollowing.ALL) { + return getConnectedNodesInfo(startNode, IoDirection.INPUT, currentNode, slot, passThroughFollowing).map((n) => n.node); +} +export function getConnectedInputInfosAndFilterPassThroughs(startNode, currentNode, slot, passThroughFollowing = PassThroughFollowing.ALL) { + return filterOutPassthroughNodes(getConnectedNodesInfo(startNode, IoDirection.INPUT, currentNode, slot, passThroughFollowing), passThroughFollowing); +} +export function getConnectedInputNodesAndFilterPassThroughs(startNode, currentNode, slot, passThroughFollowing = PassThroughFollowing.ALL) { + return getConnectedInputInfosAndFilterPassThroughs(startNode, currentNode, slot, passThroughFollowing).map(n => n.node); +} +export function getConnectedOutputNodes(startNode, currentNode, slot, passThroughFollowing = PassThroughFollowing.ALL) { + return getConnectedNodesInfo(startNode, IoDirection.OUTPUT, currentNode, slot, passThroughFollowing).map((n) => n.node); +} +export function getConnectedOutputNodesAndFilterPassThroughs(startNode, currentNode, slot, passThroughFollowing = PassThroughFollowing.ALL) { + return filterOutPassthroughNodes(getConnectedNodesInfo(startNode, IoDirection.OUTPUT, currentNode, slot, passThroughFollowing), passThroughFollowing).map(n => n.node); +} +export function getConnectedNodesInfo(startNode, dir = IoDirection.INPUT, currentNode, slot, passThroughFollowing = PassThroughFollowing.ALL, originTravelFromSlot) { + var _a, _b, _c, _d, _e, _f; + currentNode = currentNode || startNode; + let rootNodes = []; + if (startNode === currentNode || shouldPassThrough(currentNode, passThroughFollowing)) { + let linkIds; + slot = slot != null && slot > -1 ? slot : undefined; + if (dir == IoDirection.OUTPUT) { + if (slot != null) { + linkIds = [...(((_b = (_a = currentNode.outputs) === null || _a === void 0 ? void 0 : _a[slot]) === null || _b === void 0 ? void 0 : _b.links) || [])]; + } + else { + linkIds = ((_c = currentNode.outputs) === null || _c === void 0 ? void 0 : _c.flatMap((i) => i.links)) || []; + } + } + else { + if (slot != null) { + linkIds = [(_e = (_d = currentNode.inputs) === null || _d === void 0 ? void 0 : _d[slot]) === null || _e === void 0 ? void 0 : _e.link]; + } + else { + linkIds = ((_f = currentNode.inputs) === null || _f === void 0 ? void 0 : _f.map((i) => i.link)) || []; + } + } + let graph = app.graph; + for (const linkId of linkIds) { + let link = null; + if (typeof linkId == "number") { + link = graph.links[linkId]; + } + if (!link) { + continue; + } + const travelFromSlot = dir == IoDirection.OUTPUT ? link.origin_slot : link.target_slot; + const connectedId = dir == IoDirection.OUTPUT ? link.target_id : link.origin_id; + const travelToSlot = dir == IoDirection.OUTPUT ? link.target_slot : link.origin_slot; + originTravelFromSlot = originTravelFromSlot != null ? 
originTravelFromSlot : travelFromSlot; + const originNode = graph.getNodeById(connectedId); + if (!link) { + console.error("No connected node found... weird"); + continue; + } + if (rootNodes.some((n) => n.node == originNode)) { + console.log(`${startNode.title} (${startNode.id}) seems to have two links to ${originNode.title} (${originNode.id}). One may be stale: ${linkIds.join(", ")}`); + } + else { + rootNodes.push({ node: originNode, travelFromSlot, travelToSlot, originTravelFromSlot }); + if (shouldPassThrough(originNode, passThroughFollowing)) { + for (const foundNode of getConnectedNodesInfo(startNode, dir, originNode, undefined, undefined, originTravelFromSlot)) { + if (!rootNodes.map((n) => n.node).includes(foundNode.node)) { + rootNodes.push(foundNode); + } + } + } + } + } + } + return rootNodes; +} +export function followConnectionUntilType(node, dir, slotNum, skipSelf = false) { + const slots = dir === IoDirection.OUTPUT ? node.outputs : node.inputs; + if (!slots || !slots.length) { + return null; + } + let type = null; + if (slotNum) { + if (!slots[slotNum]) { + return null; + } + type = getTypeFromSlot(slots[slotNum], dir, skipSelf); + } + else { + for (const slot of slots) { + type = getTypeFromSlot(slot, dir, skipSelf); + if (type) { + break; + } + } + } + return type; +} +function getTypeFromSlot(slot, dir, skipSelf = false) { + let graph = app.graph; + let type = slot === null || slot === void 0 ? void 0 : slot.type; + if (!skipSelf && type != null && type != "*") { + return { type: type, label: slot === null || slot === void 0 ? void 0 : slot.label, name: slot === null || slot === void 0 ? void 0 : slot.name }; + } + const links = getSlotLinks(slot); + for (const link of links) { + const connectedId = dir == IoDirection.OUTPUT ? link.link.target_id : link.link.origin_id; + const connectedSlotNum = dir == IoDirection.OUTPUT ? link.link.target_slot : link.link.origin_slot; + const connectedNode = graph.getNodeById(connectedId); + const connectedSlots = dir === IoDirection.OUTPUT ? connectedNode.inputs : connectedNode.outputs; + let connectedSlot = connectedSlots[connectedSlotNum]; + if ((connectedSlot === null || connectedSlot === void 0 ? void 0 : connectedSlot.type) != null && (connectedSlot === null || connectedSlot === void 0 ? void 0 : connectedSlot.type) != "*") { + return { + type: connectedSlot.type, + label: connectedSlot === null || connectedSlot === void 0 ? void 0 : connectedSlot.label, + name: connectedSlot === null || connectedSlot === void 0 ? void 0 : connectedSlot.name, + }; + } + else if ((connectedSlot === null || connectedSlot === void 0 ? void 0 : connectedSlot.type) == "*") { + return followConnectionUntilType(connectedNode, dir); + } + } + return null; +} +export async function replaceNode(existingNode, typeOrNewNode, inputNameMap) { + const existingCtor = existingNode.constructor; + const newNode = typeof typeOrNewNode === "string" ? LiteGraph.createNode(typeOrNewNode) : typeOrNewNode; + if (existingNode.title != existingCtor.title) { + newNode.title = existingNode.title; + } + newNode.pos = [...existingNode.pos]; + newNode.properties = { ...existingNode.properties }; + const oldComputeSize = [...existingNode.computeSize()]; + const oldSize = [ + existingNode.size[0] === oldComputeSize[0] ? null : existingNode.size[0], + existingNode.size[1] === oldComputeSize[1] ? 
null : existingNode.size[1], + ]; + let setSizeIters = 0; + const setSizeFn = () => { + const newComputesize = newNode.computeSize(); + newNode.size[0] = Math.max(oldSize[0] || 0, newComputesize[0]); + newNode.size[1] = Math.max(oldSize[1] || 0, newComputesize[1]); + setSizeIters++; + if (setSizeIters > 10) { + requestAnimationFrame(setSizeFn); + } + }; + setSizeFn(); + const links = []; + for (const [index, output] of existingNode.outputs.entries()) { + for (const linkId of output.links || []) { + const link = app.graph.links[linkId]; + if (!link) + continue; + const targetNode = app.graph.getNodeById(link.target_id); + links.push({ node: newNode, slot: output.name, targetNode, targetSlot: link.target_slot }); + } + } + for (const [index, input] of existingNode.inputs.entries()) { + const linkId = input.link; + if (linkId) { + const link = app.graph.links[linkId]; + const originNode = app.graph.getNodeById(link.origin_id); + links.push({ + node: originNode, + slot: link.origin_slot, + targetNode: newNode, + targetSlot: (inputNameMap === null || inputNameMap === void 0 ? void 0 : inputNameMap.has(input.name)) + ? inputNameMap.get(input.name) + : input.name || index, + }); + } + } + app.graph.add(newNode); + await wait(); + for (const link of links) { + link.node.connect(link.slot, link.targetNode, link.targetSlot); + } + await wait(); + app.graph.remove(existingNode); + newNode.size = newNode.computeSize(); + newNode.setDirtyCanvas(true, true); + return newNode; +} +export function getOriginNodeByLink(linkId) { + let node = null; + if (linkId != null) { + const link = app.graph.links[linkId]; + node = (link != null && app.graph.getNodeById(link.origin_id)) || null; + } + return node; +} +export function applyMixins(original, constructors) { + constructors.forEach((baseCtor) => { + Object.getOwnPropertyNames(baseCtor.prototype).forEach((name) => { + Object.defineProperty(original.prototype, name, Object.getOwnPropertyDescriptor(baseCtor.prototype, name) || Object.create(null)); + }); + }); +} +export function getSlotLinks(inputOrOutput) { + var _a; + const links = []; + if (!inputOrOutput) { + return links; + } + if ((_a = inputOrOutput.links) === null || _a === void 0 ? void 0 : _a.length) { + const output = inputOrOutput; + for (const linkId of output.links || []) { + const link = app.graph.links[linkId]; + if (link) { + links.push({ id: linkId, link: link }); + } + } + } + if (inputOrOutput.link) { + const input = inputOrOutput; + const link = app.graph.links[input.link]; + if (link) { + links.push({ id: input.link, link: link }); + } + } + return links; +} +export async function matchLocalSlotsToServer(node, direction, serverNodeData) { + var _a, _b, _c; + const serverSlotNames = direction == IoDirection.INPUT + ? Object.keys(((_a = serverNodeData.input) === null || _a === void 0 ? void 0 : _a.optional) || {}) + : serverNodeData.output_name; + const serverSlotTypes = direction == IoDirection.INPUT + ? Object.values(((_b = serverNodeData.input) === null || _b === void 0 ? void 0 : _b.optional) || {}).map((i) => i[0]) + : serverNodeData.output; + const slots = direction == IoDirection.INPUT ? node.inputs : node.outputs; + let firstIndex = slots.findIndex((o, i) => i !== serverSlotNames.indexOf(o.name)); + if (firstIndex > -1) { + const links = {}; + slots.map((slot) => { + var _a; + links[slot.name] = links[slot.name] || []; + (_a = links[slot.name]) === null || _a === void 0 ? 
void 0 : _a.push(...getSlotLinks(slot)); + }); + for (const [index, serverSlotName] of serverSlotNames.entries()) { + const currentNodeSlot = slots.map((s) => s.name).indexOf(serverSlotName); + if (currentNodeSlot > -1) { + if (currentNodeSlot != index) { + const splicedItem = slots.splice(currentNodeSlot, 1)[0]; + slots.splice(index, 0, splicedItem); + } + } + else if (currentNodeSlot === -1) { + const splicedItem = { + name: serverSlotName, + type: serverSlotTypes[index], + links: [], + }; + slots.splice(index, 0, splicedItem); + } + } + if (slots.length > serverSlotNames.length) { + for (let i = slots.length - 1; i > serverSlotNames.length - 1; i--) { + if (direction == IoDirection.INPUT) { + node.disconnectInput(i); + node.removeInput(i); + } + else { + node.disconnectOutput(i); + node.removeOutput(i); + } + } + } + for (const [name, slotLinks] of Object.entries(links)) { + let currentNodeSlot = slots.map((s) => s.name).indexOf(name); + if (currentNodeSlot > -1) { + for (const linkData of slotLinks) { + if (direction == IoDirection.INPUT) { + linkData.link.target_slot = currentNodeSlot; + } + else { + linkData.link.origin_slot = currentNodeSlot; + const nextNode = app.graph.getNodeById(linkData.link.target_id); + if (nextNode && + ((_c = nextNode.constructor) === null || _c === void 0 ? void 0 : _c.type.includes("Reroute"))) { + nextNode.stabilize && nextNode.stabilize(); + } + } + } + } + } + } +} +export function isValidConnection(ioA, ioB) { + if (!ioA || !ioB) { + return false; + } + const typeA = String(ioA.type); + const typeB = String(ioB.type); + let isValid = LiteGraph.isValidConnection(typeA, typeB); + if (!isValid) { + let areCombos = (typeA.includes(",") && typeB === "COMBO") || (typeA === "COMBO" && typeB.includes(",")); + if (areCombos) { + const nameA = ioA.name.toUpperCase().replace("_NAME", "").replace("CKPT", "MODEL"); + const nameB = ioB.name.toUpperCase().replace("_NAME", "").replace("CKPT", "MODEL"); + isValid = nameA.includes(nameB) || nameB.includes(nameA); + } + } + return isValid; +} +const oldIsValidConnection = LiteGraph.isValidConnection; +LiteGraph.isValidConnection = function (typeA, typeB) { + let isValid = oldIsValidConnection.call(LiteGraph, typeA, typeB); + if (!isValid) { + typeA = String(typeA); + typeB = String(typeB); + let areCombos = (typeA.includes(",") && typeB === "COMBO") || (typeA === "COMBO" && typeB.includes(",")); + isValid = areCombos; + } + return isValid; +}; diff --git a/rgthree-comfy/web/comfyui/utils_canvas.js b/rgthree-comfy/web/comfyui/utils_canvas.js new file mode 100644 index 0000000000000000000000000000000000000000..b8884a58d6f4a4ceb40b884d0fa46f55387ba7cf --- /dev/null +++ b/rgthree-comfy/web/comfyui/utils_canvas.js @@ -0,0 +1,151 @@ +import { app } from "../../scripts/app.js"; +function binarySearch(max, getValue, match) { + let min = 0; + while (min <= max) { + let guess = Math.floor((min + max) / 2); + const compareVal = getValue(guess); + if (compareVal === match) + return guess; + if (compareVal < match) + min = guess + 1; + else + max = guess - 1; + } + return max; +} +export function fitString(ctx, str, maxWidth) { + let width = ctx.measureText(str).width; + const ellipsis = "…"; + const ellipsisWidth = measureText(ctx, ellipsis); + if (width <= maxWidth || width <= ellipsisWidth) { + return str; + } + const index = binarySearch(str.length, (guess) => measureText(ctx, str.substring(0, guess)), maxWidth - ellipsisWidth); + return str.substring(0, index) + ellipsis; +} +export function measureText(ctx, str) { + return 
ctx.measureText(str).width; +} +export function isLowQuality() { + var _a; + const canvas = app.canvas; + return (((_a = canvas.ds) === null || _a === void 0 ? void 0 : _a.scale) || 1) <= 0.5; +} +export function drawNodeWidget(ctx, options) { + const lowQuality = isLowQuality(); + const data = { + width: options.width, + height: options.height, + posY: options.posY, + lowQuality, + margin: 15, + colorOutline: LiteGraph.WIDGET_OUTLINE_COLOR, + colorBackground: LiteGraph.WIDGET_BGCOLOR, + colorText: LiteGraph.WIDGET_TEXT_COLOR, + colorTextSecondary: LiteGraph.WIDGET_SECONDARY_TEXT_COLOR, + }; + ctx.strokeStyle = options.colorStroke || data.colorOutline; + ctx.fillStyle = options.colorBackground || data.colorBackground; + ctx.beginPath(); + ctx.roundRect(data.margin, data.posY, data.width - data.margin * 2, data.height, lowQuality ? [0] : options.borderRadius ? [options.borderRadius] : [options.height * 0.5]); + ctx.fill(); + if (!lowQuality) { + ctx.stroke(); + } + return data; +} +export function drawRoundedRectangle(ctx, options) { + const lowQuality = isLowQuality(); + options = { ...options }; + ctx.strokeStyle = options.colorStroke || LiteGraph.WIDGET_OUTLINE_COLOR; + ctx.fillStyle = options.colorBackground || LiteGraph.WIDGET_BGCOLOR; + ctx.beginPath(); + ctx.roundRect(options.posX, options.posY, options.width, options.height, lowQuality ? [0] : options.borderRadius ? [options.borderRadius] : [options.height * 0.5]); + ctx.fill(); + !lowQuality && ctx.stroke(); +} +export function drawNumberWidgetPart(ctx, options) { + const arrowWidth = 9; + const arrowHeight = 10; + const innerMargin = 3; + const numberWidth = 32; + const xBoundsArrowLess = [0, 0]; + const xBoundsNumber = [0, 0]; + const xBoundsArrowMore = [0, 0]; + ctx.save(); + let posX = options.posX; + const { posY, height, value, textColor } = options; + const midY = posY + height / 2; + if (options.direction === -1) { + posX = posX - arrowWidth - innerMargin - numberWidth - innerMargin - arrowWidth; + } + ctx.fill(new Path2D(`M ${posX} ${midY} l ${arrowWidth} ${arrowHeight / 2} l 0 -${arrowHeight} L ${posX} ${midY} z`)); + xBoundsArrowLess[0] = posX; + xBoundsArrowLess[1] = arrowWidth; + posX += arrowWidth + innerMargin; + ctx.textAlign = "center"; + ctx.textBaseline = "middle"; + const oldTextcolor = ctx.fillStyle; + if (textColor) { + ctx.fillStyle = textColor; + } + ctx.fillText(fitString(ctx, value.toFixed(2), numberWidth), posX + numberWidth / 2, midY); + ctx.fillStyle = oldTextcolor; + xBoundsNumber[0] = posX; + xBoundsNumber[1] = numberWidth; + posX += numberWidth + innerMargin; + ctx.fill(new Path2D(`M ${posX} ${midY - arrowHeight / 2} l ${arrowWidth} ${arrowHeight / 2} l -${arrowWidth} ${arrowHeight / 2} v -${arrowHeight} z`)); + xBoundsArrowMore[0] = posX; + xBoundsArrowMore[1] = arrowWidth; + ctx.restore(); + return [xBoundsArrowLess, xBoundsNumber, xBoundsArrowMore]; +} +drawNumberWidgetPart.WIDTH_TOTAL = 9 + 3 + 32 + 3 + 9; +export function drawTogglePart(ctx, options) { + const lowQuality = isLowQuality(); + ctx.save(); + const { posX, posY, height, value } = options; + const toggleRadius = height * 0.36; + const toggleBgWidth = height * 1.5; + if (!lowQuality) { + ctx.beginPath(); + ctx.roundRect(posX + 4, posY + 4, toggleBgWidth - 8, height - 8, [height * 0.5]); + ctx.globalAlpha = app.canvas.editor_alpha * 0.25; + ctx.fillStyle = "rgba(255,255,255,0.45)"; + ctx.fill(); + ctx.globalAlpha = app.canvas.editor_alpha; + } + ctx.fillStyle = value === true ? 
"#89B" : "#888"; + const toggleX = lowQuality || value === false + ? posX + height * 0.5 + : value === true + ? posX + height + : posX + height * 0.75; + ctx.beginPath(); + ctx.arc(toggleX, posY + height * 0.5, toggleRadius, 0, Math.PI * 2); + ctx.fill(); + ctx.restore(); + return [posX, toggleBgWidth]; +} +export function drawInfoIcon(ctx, x, y, size = 12) { + ctx.save(); + ctx.beginPath(); + ctx.roundRect(x, y, size, size, [size * 0.1]); + ctx.fillStyle = "#2f82ec"; + ctx.strokeStyle = "#0f2a5e"; + ctx.fill(); + ctx.strokeStyle = "#FFF"; + ctx.lineWidth = 2; + const midX = x + size / 2; + const serifSize = size * 0.175; + ctx.stroke(new Path2D(` + M ${midX} ${y + size * 0.15} + v 2 + M ${midX - serifSize} ${y + size * 0.45} + h ${serifSize} + v ${size * 0.325} + h ${serifSize} + h -${serifSize * 2} + `)); + ctx.restore(); +} diff --git a/rgthree-comfy/web/comfyui/utils_inputs_outputs.js b/rgthree-comfy/web/comfyui/utils_inputs_outputs.js new file mode 100644 index 0000000000000000000000000000000000000000..52fdd15cf078c52b1b2999e92c4bff0bd2914648 --- /dev/null +++ b/rgthree-comfy/web/comfyui/utils_inputs_outputs.js @@ -0,0 +1,12 @@ +export function removeUnusedInputsFromEnd(node, minNumber = 1, nameMatch) { + var _a; + for (let i = node.inputs.length - 1; i >= minNumber; i--) { + if (!((_a = node.inputs[i]) === null || _a === void 0 ? void 0 : _a.link)) { + if (!nameMatch || nameMatch.test(node.inputs[i].name)) { + node.removeInput(i); + } + continue; + } + break; + } +} diff --git a/rgthree-comfy/web/comfyui/utils_menu.js b/rgthree-comfy/web/comfyui/utils_menu.js new file mode 100644 index 0000000000000000000000000000000000000000..a94ae142891630faebce903fe2a8abf417b467b5 --- /dev/null +++ b/rgthree-comfy/web/comfyui/utils_menu.js @@ -0,0 +1,55 @@ +import { app } from "../../scripts/app.js"; +import { rgthreeApi } from "../../rgthree/common/rgthree_api.js"; +const PASS_THROUGH = function (item) { + return item; +}; +export async function showLoraChooser(event, callback, parentMenu, loras) { + var _a, _b; + const canvas = app.canvas; + if (!loras) { + loras = ["None", ...(await rgthreeApi.getLoras())]; + } + new LiteGraph.ContextMenu(loras, { + event: event, + parentMenu, + title: "Choose a lora", + scale: Math.max(1, (_b = (_a = canvas.ds) === null || _a === void 0 ? void 0 : _a.scale) !== null && _b !== void 0 ? _b : 1), + className: "dark", + callback, + }); +} +export function showNodesChooser(event, mapFn, callback, parentMenu) { + var _a, _b; + const canvas = app.canvas; + const nodesOptions = app.graph._nodes + .map(mapFn) + .filter((e) => e != null); + nodesOptions.sort((a, b) => { + return a.value - b.value; + }); + new LiteGraph.ContextMenu(nodesOptions, { + event: event, + parentMenu, + title: "Choose a node id", + scale: Math.max(1, (_b = (_a = canvas.ds) === null || _a === void 0 ? void 0 : _a.scale) !== null && _b !== void 0 ? _b : 1), + className: "dark", + callback, + }); +} +export function showWidgetsChooser(event, node, mapFn, callback, parentMenu) { + var _a, _b; + const options = (node.widgets || []) + .map(mapFn) + .filter((e) => e != null); + if (options.length) { + const canvas = app.canvas; + new LiteGraph.ContextMenu(options, { + event, + parentMenu, + title: "Choose an input/widget", + scale: Math.max(1, (_b = (_a = canvas.ds) === null || _a === void 0 ? void 0 : _a.scale) !== null && _b !== void 0 ? 
_b : 1), + className: "dark", + callback, + }); + } +} diff --git a/rgthree-comfy/web/comfyui/utils_widgets.js b/rgthree-comfy/web/comfyui/utils_widgets.js new file mode 100644 index 0000000000000000000000000000000000000000..81bed18c60150940442cb15edf38f51219ea6bdb --- /dev/null +++ b/rgthree-comfy/web/comfyui/utils_widgets.js @@ -0,0 +1,265 @@ +import { app } from "../../scripts/app.js"; +import { drawNodeWidget, drawRoundedRectangle, fitString, isLowQuality } from "./utils_canvas.js"; +export function drawLabelAndValue(ctx, label, value, width, posY, height, options) { + var _a; + const outerMargin = 15; + const innerMargin = 10; + const midY = posY + height / 2; + ctx.save(); + ctx.textAlign = "left"; + ctx.textBaseline = "middle"; + ctx.fillStyle = LiteGraph.WIDGET_SECONDARY_TEXT_COLOR; + const labelX = outerMargin + innerMargin + ((_a = options === null || options === void 0 ? void 0 : options.offsetLeft) !== null && _a !== void 0 ? _a : 0); + ctx.fillText(label, labelX, midY); + const valueXLeft = labelX + ctx.measureText(label).width + 7; + const valueXRight = width - (outerMargin + innerMargin); + ctx.fillStyle = LiteGraph.WIDGET_TEXT_COLOR; + ctx.textAlign = "right"; + ctx.fillText(fitString(ctx, value, valueXRight - valueXLeft), valueXRight, midY); + ctx.restore(); +} +export class RgthreeBaseWidget { + constructor(name) { + this.last_y = 0; + this.mouseDowned = null; + this.isMouseDownedAndOver = false; + this.hitAreas = {}; + this.downedHitAreasForMove = []; + this.name = name; + } + clickWasWithinBounds(pos, bounds) { + let xStart = bounds[0]; + let xEnd = xStart + (bounds.length > 2 ? bounds[2] : bounds[1]); + const clickedX = pos[0] >= xStart && pos[0] <= xEnd; + if (bounds.length === 2) { + return clickedX; + } + return clickedX && pos[1] >= bounds[1] && pos[1] <= bounds[1] + bounds[3]; + } + mouse(event, pos, node) { + var _a, _b, _c; + const canvas = app.canvas; + if (event.type == "pointerdown") { + this.mouseDowned = [...pos]; + this.isMouseDownedAndOver = true; + this.downedHitAreasForMove.length = 0; + let anyHandled = false; + for (const part of Object.values(this.hitAreas)) { + if ((part.onDown || part.onMove) && this.clickWasWithinBounds(pos, part.bounds)) { + if (part.onMove) { + this.downedHitAreasForMove.push(part); + } + if (part.onDown) { + const thisHandled = part.onDown.apply(this, [event, pos, node, part]); + anyHandled = anyHandled || thisHandled == true; + } + } + } + return (_a = this.onMouseDown(event, pos, node)) !== null && _a !== void 0 ? _a : anyHandled; + } + if (event.type == "pointerup") { + if (!this.mouseDowned) + return true; + this.downedHitAreasForMove.length = 0; + this.cancelMouseDown(); + let anyHandled = false; + for (const part of Object.values(this.hitAreas)) { + if (part.onUp && this.clickWasWithinBounds(pos, part.bounds)) { + const thisHandled = part.onUp.apply(this, [event, pos, node, part]); + anyHandled = anyHandled || thisHandled == true; + } + } + return (_b = this.onMouseUp(event, pos, node)) !== null && _b !== void 0 ? _b : anyHandled; + } + if (event.type == "pointermove") { + this.isMouseDownedAndOver = !!this.mouseDowned; + if (this.mouseDowned && + (pos[0] < 15 || + pos[0] > node.size[0] - 15 || + pos[1] < this.last_y || + pos[1] > this.last_y + LiteGraph.NODE_WIDGET_HEIGHT)) { + this.isMouseDownedAndOver = false; + } + for (const part of this.downedHitAreasForMove) { + part.onMove.apply(this, [event, pos, node, part]); + } + return (_c = this.onMouseMove(event, pos, node)) !== null && _c !== void 0 ? 
_c : true; + } + return false; + } + cancelMouseDown() { + this.mouseDowned = null; + this.isMouseDownedAndOver = false; + this.downedHitAreasForMove.length = 0; + } + onMouseDown(event, pos, node) { + return; + } + onMouseUp(event, pos, node) { + return; + } + onMouseMove(event, pos, node) { + return; + } +} +export class RgthreeBetterButtonWidget extends RgthreeBaseWidget { + constructor(name, mouseUpCallback) { + super(name); + this.value = ""; + this.mouseUpCallback = mouseUpCallback; + } + draw(ctx, node, width, y, height) { + drawWidgetButton({ ctx, node, width, height, y }, this.name, this.isMouseDownedAndOver); + } + onMouseUp(event, pos, node) { + return this.mouseUpCallback(event, pos, node); + } +} +export class RgthreeBetterTextWidget { + constructor(name, value) { + this.name = name; + this.value = value; + } + draw(ctx, node, width, y, height) { + const widgetData = drawNodeWidget(ctx, { width, height, posY: y }); + if (!widgetData.lowQuality) { + drawLabelAndValue(ctx, this.name, this.value, width, y, height); + } + } + mouse(event, pos, node) { + const canvas = app.canvas; + if (event.type == "pointerdown") { + canvas.prompt("Label", this.value, (v) => (this.value = v), event); + return true; + } + return false; + } +} +export class RgthreeDividerWidget { + constructor(widgetOptions) { + this.options = { serialize: false }; + this.value = null; + this.name = "divider"; + this.widgetOptions = { + marginTop: 7, + marginBottom: 7, + marginLeft: 15, + marginRight: 15, + color: LiteGraph.WIDGET_OUTLINE_COLOR, + thickness: 1, + }; + Object.assign(this.widgetOptions, widgetOptions || {}); + } + draw(ctx, node, width, posY, h) { + if (this.widgetOptions.thickness) { + ctx.strokeStyle = this.widgetOptions.color; + const x = this.widgetOptions.marginLeft; + const y = posY + this.widgetOptions.marginTop; + const w = width - this.widgetOptions.marginLeft - this.widgetOptions.marginRight; + ctx.stroke(new Path2D(`M ${x} ${y} h ${w}`)); + } + } + computeSize(width) { + return [ + width, + this.widgetOptions.marginTop + this.widgetOptions.marginBottom + this.widgetOptions.thickness, + ]; + } +} +export class RgthreeLabelWidget { + constructor(name, widgetOptions) { + this.options = { serialize: false }; + this.value = null; + this.widgetOptions = {}; + this.posY = 0; + this.name = name; + Object.assign(this.widgetOptions, widgetOptions); + } + draw(ctx, node, width, posY, height) { + this.posY = posY; + ctx.save(); + ctx.textAlign = this.widgetOptions.align || "left"; + ctx.fillStyle = this.widgetOptions.color || LiteGraph.WIDGET_TEXT_COLOR; + const oldFont = ctx.font; + if (this.widgetOptions.italic) { + ctx.font = "italic " + ctx.font; + } + if (this.widgetOptions.size) { + ctx.font = ctx.font.replace(/\d+px/, `${this.widgetOptions.size}px`); + } + const midY = posY + height / 2; + ctx.textBaseline = "middle"; + if (this.widgetOptions.align === "center") { + ctx.fillText(this.name, node.size[0] / 2, midY); + } + else { + ctx.fillText(this.name, 15, midY); + } + ctx.font = oldFont; + if (this.widgetOptions.actionLabel === "__PLUS_ICON__") { + const plus = new Path2D(`M${node.size[0] - 15 - 2} ${posY + 7} v4 h-4 v4 h-4 v-4 h-4 v-4 h4 v-4 h4 v4 h4 z`); + ctx.lineJoin = "round"; + ctx.lineCap = "round"; + ctx.fillStyle = "#3a3"; + ctx.strokeStyle = "#383"; + ctx.fill(plus); + ctx.stroke(plus); + } + ctx.restore(); + } + mouse(event, nodePos, node) { + if (event.type !== "pointerdown" || + isLowQuality() || + !this.widgetOptions.actionLabel || + !this.widgetOptions.actionCallback) { + return 
false; + } + const pos = [nodePos[0], nodePos[1] - this.posY]; + const rightX = node.size[0] - 15; + if (pos[0] > rightX || pos[0] < rightX - 16) { + return false; + } + this.widgetOptions.actionCallback(event); + return true; + } +} +export class RgthreeInvisibleWidget { + constructor(name, type, value, serializeValueFn) { + this.serializeValue = undefined; + this.name = name; + this.type = type; + this.value = value; + if (serializeValueFn) { + this.serializeValue = serializeValueFn; + } + } + draw() { return; } + computeSize(width) { return [0, 0]; } +} +export function drawWidgetButton(drawCtx, text, isMouseDownedAndOver = false) { + if (!isLowQuality() && !isMouseDownedAndOver) { + drawRoundedRectangle(drawCtx.ctx, { + width: drawCtx.width - 30 - 2, + height: drawCtx.height, + posY: drawCtx.y + 1, + posX: 15 + 1, + borderRadius: 4, + colorBackground: "#000000aa", + colorStroke: "#000000aa", + }); + } + drawRoundedRectangle(drawCtx.ctx, { + width: drawCtx.width - 30, + height: drawCtx.height, + posY: drawCtx.y + (isMouseDownedAndOver ? 1 : 0), + posX: 15, + borderRadius: isLowQuality() ? 0 : 4, + colorBackground: isMouseDownedAndOver ? "#444" : LiteGraph.WIDGET_BGCOLOR, + }); + if (!isLowQuality()) { + drawCtx.ctx.textBaseline = "middle"; + drawCtx.ctx.textAlign = "center"; + drawCtx.ctx.fillStyle = LiteGraph.WIDGET_TEXT_COLOR; + drawCtx.ctx.fillText(text, drawCtx.node.size[0] / 2, drawCtx.y + drawCtx.height / 2 + (isMouseDownedAndOver ? 1 : 0)); + } +} diff --git a/rgthree-comfy/web/common/css/buttons.css b/rgthree-comfy/web/common/css/buttons.css new file mode 100644 index 0000000000000000000000000000000000000000..d74ebaab92257390a56a7ddbe71f93721290bc01 --- /dev/null +++ b/rgthree-comfy/web/common/css/buttons.css @@ -0,0 +1,90 @@ +:not(#fakeid) .rgthree-button-reset { + position: relative; + appearance: none; + cursor: pointer; + border: 0; + background: transparent; + color: inherit; + padding: 0; + margin: 0; +} + +:not(#fakeid) .rgthree-button { + --padding-top: 7px; + --padding-bottom: 9px; + --padding-x: 16px; + position: relative; + cursor: pointer; + border: 0; + border-radius: 0.25rem; + background: rgba(0, 0, 0, 0.5); + color: white; + font-family: system-ui, sans-serif; + font-size: 1rem; + line-height: 1; + white-space: nowrap; + text-decoration: none; + margin: 0.25rem; + box-shadow: 0px 0px 2px rgb(0, 0, 0); + background: #212121; + transition: all 0.1s ease-in-out; + padding: var(--padding-top) var(--padding-x) var(--padding-bottom); + display: inline-flex; + flex-direction: row; + align-items: center; + justify-content: center; +} +:not(#fakeid) .rgthree-button::before, :not(#fakeid) .rgthree-button::after { + content: ""; + display: block; + position: absolute; + border-radius: 0.25rem; + left: 0; + top: 0; + width: 100%; + height: 100%; + box-shadow: inset 1px 1px 0px rgba(255, 255, 255, 0.12), inset -1px -1px 0px rgba(0, 0, 0, 0.75); + background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)); + mix-blend-mode: screen; +} +:not(#fakeid) .rgthree-button::after { + mix-blend-mode: multiply; +} +:not(#fakeid) .rgthree-button:hover { + background: #303030; +} +:not(#fakeid) .rgthree-button:active { + box-shadow: 0px 0px 0px rgba(0, 0, 0, 0); + background: #121212; + padding: calc(var(--padding-top) + 1px) calc(var(--padding-x) - 1px) calc(var(--padding-bottom) - 1px) calc(var(--padding-x) + 1px); +} +:not(#fakeid) .rgthree-button:active::before, :not(#fakeid) .rgthree-button:active::after { + box-shadow: 1px 1px 0px rgba(255, 255, 255, 
0.15), inset 1px 1px 0px rgba(0, 0, 0, 0.5), inset 1px 3px 5px rgba(0, 0, 0, 0.33); +} +:not(#fakeid) .rgthree-button.-blue { + background: #346599 !important; +} +:not(#fakeid) .rgthree-button.-blue:hover { + background: #3b77b8 !important; +} +:not(#fakeid) .rgthree-button.-blue:active { + background: #1d5086 !important; +} +:not(#fakeid) .rgthree-button.-green { + background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)), #14580b; +} +:not(#fakeid) .rgthree-button.-green:hover { + background: linear-gradient(to bottom, rgba(255, 255, 255, 0.06), rgba(0, 0, 0, 0.15)), #1a6d0f; +} +:not(#fakeid) .rgthree-button.-green:active { + background: linear-gradient(to bottom, rgba(0, 0, 0, 0.15), rgba(255, 255, 255, 0.06)), #0f3f09; +} +:not(#fakeid) .rgthree-button[disabled] { + box-shadow: none; + background: #666 !important; + color: #aaa; + pointer-events: none; +} +:not(#fakeid) .rgthree-button[disabled]::before, :not(#fakeid) .rgthree-button[disabled]::after { + display: none; +} diff --git a/rgthree-comfy/web/common/css/dialog.css b/rgthree-comfy/web/common/css/dialog.css new file mode 100644 index 0000000000000000000000000000000000000000..c9efdf67c5848facf4589d308ed53e300d67cdd9 --- /dev/null +++ b/rgthree-comfy/web/common/css/dialog.css @@ -0,0 +1,124 @@ +@charset "UTF-8"; +.rgthree-dialog { + outline: 0; + border: 0; + border-radius: 6px; + background: #414141; + color: #fff; + box-shadow: inset 1px 1px 0px rgba(255, 255, 255, 0.05), inset -1px -1px 0px rgba(0, 0, 0, 0.5), 2px 2px 20px rgb(0, 0, 0); + max-width: 800px; + box-sizing: border-box; + font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif; + font-size: 1rem; + padding: 0; + max-height: calc(100% - 32px); +} +.rgthree-dialog *, .rgthree-dialog *::before, .rgthree-dialog *::after { + box-sizing: inherit; +} + +.rgthree-dialog-container > * { + padding: 8px 16px; +} +.rgthree-dialog-container > *:first-child { + padding-top: 16px; +} +.rgthree-dialog-container > *:last-child { + padding-bottom: 16px; +} + +.rgthree-dialog.-iconed::after { + content: ""; + font-size: 276px; + position: absolute; + right: 0px; + bottom: 0px; + opacity: 0.15; + display: block; + width: 237px; + overflow: hidden; + height: 186px; + line-height: 1; + pointer-events: none; + z-index: -1; +} + +.rgthree-dialog.-iconed.-help::after { + content: "🛟"; +} + +.rgthree-dialog.-iconed.-settings::after { + content: "⚙️"; +} + +@media (max-width: 832px) { + .rgthree-dialog { + max-width: calc(100% - 32px); + } +} +.rgthree-dialog-container-title { + display: flex; + flex-direction: row; + align-items: center; + justify-content: start; +} + +.rgthree-dialog-container-title > svg:first-child { + width: 36px; + height: 36px; + margin-right: 16px; +} + +.rgthree-dialog-container-title h2 { + font-size: 1.375rem; + margin: 0; + font-weight: bold; +} + +.rgthree-dialog-container-title h2 small { + font-size: 0.8125rem; + font-weight: normal; + opacity: 0.75; +} + +.rgthree-dialog-container-content { + overflow: auto; + max-height: calc(100vh - 200px); /* Arbitrary height to compensate for margin, title, and footer. */ +} + +.rgthree-dialog-container-content p { + font-size: 0.8125rem; + margin-top: 0; +} + +.rgthree-dialog-container-content ul li p { + margin-bottom: 4px; +} + +.rgthree-dialog-container-content ul li p + p { + margin-top: 0.5em; +} + +.rgthree-dialog-container-content ul li ul { + margin-top: 0.5em; + margin-bottom: 1em; +} + +.rgthree-dialog-container-content p code { + display: inline-block; + padding: 2px 4px; +
margin: 0px 2px; + border: 1px solid rgba(255, 255, 255, 0.25); + border-radius: 3px; + background: rgba(255, 255, 255, 0.1); +} + +.rgthree-dialog-container-footer { + display: flex; + align-items: center; + justify-content: center; +} + +body.rgthree-dialog-open > *:not(.rgthree-dialog):not(.rgthree-top-messages-container) { + filter: blur(5px); +} diff --git a/rgthree-comfy/web/common/css/dialog_lora_chooser.css b/rgthree-comfy/web/common/css/dialog_lora_chooser.css new file mode 100644 index 0000000000000000000000000000000000000000..0142ab10a522512a0cc8dbdb11ce5793baa41f11 --- /dev/null +++ b/rgthree-comfy/web/common/css/dialog_lora_chooser.css @@ -0,0 +1,143 @@ +.rgthree-lora-chooser-dialog { + max-width: 100%; +} +.rgthree-lora-chooser-dialog .rgthree-dialog-container-title { + display: flex; + flex-direction: column; +} +.rgthree-lora-chooser-dialog .rgthree-dialog-container-title h2 { + display: flex; + width: 100%; +} +.rgthree-lora-chooser-dialog .rgthree-lora-chooser-search { + margin-left: auto; + border-radius: 50px; + width: 50%; + max-width: 170px; + padding: 2px 8px; +} +.rgthree-lora-chooser-dialog .rgthree-lora-chooser-header { + display: flex; + flex-direction: row; +} +.rgthree-lora-chooser-dialog .rgthree-lora-filters-container svg { + width: 16px; + height: 16px; +} +.rgthree-lora-chooser-dialog .rgthree-dialog-container-content { + width: 80vw; + height: 80vh; +} +.rgthree-lora-chooser-dialog .rgthree-button-reset { + width: 32px; + height: 32px; +} +.rgthree-lora-chooser-dialog .rgthree-button-reset > svg { + width: 100%; + height: 100%; +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list { + list-style: none; + margin: 0; + padding: 0; + position: relative; + display: flex; + flex-direction: row; + flex-wrap: wrap; + align-items: start; + justify-content: space-around; +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li { + position: relative; + flex: 0 0 auto; + width: 170px; + max-width: 100%; + margin: 8px 8px 16px; +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li label { + position: absolute; + display: block; + inset: 0; + z-index: 3; + cursor: pointer; +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li input[type=checkbox] { + position: absolute; + right: 8px; + top: 8px; + margin: 0; + z-index: 2; + appearance: none; + background-color: #fff; + width: 48px; + height: 48px; + border-radius: 4px; + border: 1px solid rgb(120, 120, 120); + opacity: 0; + transition: opacity 0.15s ease-in-out; +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li input[type=checkbox]:checked { + opacity: 1; + background: #0060df; +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li input[type=checkbox]:checked::before { + content: ""; + display: block; + width: 100%; + height: 100%; + box-shadow: inset 100px 100px #fff; + clip-path: polygon(40.13% 68.39%, 23.05% 51.31%, 17.83% 48.26%, 12.61% 49.57%, 9.57% 53.04%, 8% 60%, 34.13% 85.87%, 39.82% 89.57%, 45.88% 86.73%, 90.66% 32.39%, 88.92% 26.1%, 83.03% 22.17%, 76.94% 22.62%); +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li figure { + position: relative; + display: block; + margin: 0 0 8px; + padding: 0; + border: 1px solid rgba(120, 120, 120, 0.8); + background: rgba(120, 120, 120, 0.5); + width: 100%; + padding-top: 120%; + transition: box-shadow 0.15s ease-in-out; + opacity: 0.75; +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li figure::after { + content: ""; + display: block; + position: absolute; + inset: 0; 
+} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li figure:empty::before { + content: "No image."; + color: rgba(200, 200, 200, 0.8); + position: absolute; + display: block; + inset: 0; + font-size: 1.2em; + text-align: center; + display: flex; + align-items: center; + justify-content: center; +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li figure > img, .rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li figure > video { + position: absolute; + width: 100%; + height: 100%; + top: 0; + left: 0; + object-fit: cover; +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li div { + word-wrap: break-word; + font-size: 0.8rem; + opacity: 0.75; +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li:hover figure::after { + box-shadow: 0px 2px 6px rgba(0, 0, 0, 0.75); +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li :checked ~ figure::after { + box-shadow: 0 0 5px #fff, 0px 0px 15px rgba(49, 131, 255, 0.88), inset 0 0 3px #fff, inset 0px 0px 5px rgba(49, 131, 255, 0.88); +} +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li:hover *, .rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li:hover input[type=checkbox], +.rgthree-lora-chooser-dialog ul.rgthree-lora-chooser-list > li :checked ~ * { + opacity: 1; +} diff --git a/rgthree-comfy/web/common/css/dialog_model_info.css b/rgthree-comfy/web/common/css/dialog_model_info.css new file mode 100644 index 0000000000000000000000000000000000000000..83ed75ee4e33b1b8268c32d2757ce14a18dea26e --- /dev/null +++ b/rgthree-comfy/web/common/css/dialog_model_info.css @@ -0,0 +1,333 @@ +.rgthree-info-dialog { + width: 90vw; + max-width: 960px; +} +.rgthree-info-dialog .rgthree-info-area { + list-style: none; + padding: 0; + margin: 0; + display: flex; +} +.rgthree-info-dialog .rgthree-info-area > li { + display: inline-flex; + margin: 0; + vertical-align: top; +} +.rgthree-info-dialog .rgthree-info-area > li + li { + margin-left: 6px; +} +.rgthree-info-dialog .rgthree-info-area > li:not(.-link) + li.-link { + margin-left: auto; +} +.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > * { + min-height: 24px; + border-radius: 4px; + line-height: 1; + color: rgba(255, 255, 255, 0.85); + background: rgb(69, 92, 85); + font-size: 14px; + font-weight: bold; + text-decoration: none; + display: flex; + height: 1.6em; + padding-left: 0.5em; + padding-right: 0.5em; + padding-bottom: 0.1em; + align-content: center; + justify-content: center; + align-items: center; + box-shadow: inset 0px 0px 0 1px rgba(0, 0, 0, 0.5); +} +.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > * > svg { + width: 16px; + height: 16px; +} +.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > * > svg:last-child { + margin-left: 0.5em; +} +.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > *[href] { + box-shadow: inset 0px 1px 0px rgba(255, 255, 255, 0.25), inset 0px -1px 0px rgba(0, 0, 0, 0.66); +} +.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-tag > *:empty { + display: none; +} +.rgthree-info-dialog .rgthree-info-area > li.-type > * { + background: rgb(73, 54, 94); + color: rgb(228, 209, 248); +} +.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-menu { + margin-left: auto; +} +:not(#fakeid) .rgthree-info-dialog .rgthree-info-area > li.rgthree-info-menu .rgthree-button { + margin: 0; + min-height: 24px; + padding: 0 12px; +} +.rgthree-info-dialog .rgthree-info-area > li.rgthree-info-menu svg { + width: 16px; + height: 16px; +} 
+.rgthree-info-dialog .rgthree-info-table { + border-collapse: collapse; + margin: 16px 0px; + width: 100%; + font-size: 12px; +} +.rgthree-info-dialog .rgthree-info-table tr.editable button { + display: flex; + width: 28px; + height: 28px; + align-items: center; + justify-content: center; +} +.rgthree-info-dialog .rgthree-info-table tr.editable button svg + svg { + display: none; +} +.rgthree-info-dialog .rgthree-info-table tr.editable.-rgthree-editing button svg { + display: none; +} +.rgthree-info-dialog .rgthree-info-table tr.editable.-rgthree-editing button svg + svg { + display: inline-block; +} +.rgthree-info-dialog .rgthree-info-table td { + position: relative; + border: 1px solid rgba(255, 255, 255, 0.25); + padding: 0; + vertical-align: top; +} +.rgthree-info-dialog .rgthree-info-table td:first-child { + background: rgba(255, 255, 255, 0.075); + width: 10px; +} +.rgthree-info-dialog .rgthree-info-table td:first-child > *:first-child { + white-space: nowrap; + padding-right: 32px; +} +.rgthree-info-dialog .rgthree-info-table td:first-child small { + display: block; + margin-top: 2px; + opacity: 0.75; +} +.rgthree-info-dialog .rgthree-info-table td:first-child small > [data-action] { + text-decoration: underline; + cursor: pointer; +} +.rgthree-info-dialog .rgthree-info-table td:first-child small > [data-action]:hover { + text-decoration: none; +} +.rgthree-info-dialog .rgthree-info-table td a, .rgthree-info-dialog .rgthree-info-table td a:hover, .rgthree-info-dialog .rgthree-info-table td a:visited { + color: inherit; +} +.rgthree-info-dialog .rgthree-info-table td svg { + width: 1.3333em; + height: 1.3333em; + vertical-align: -0.285em; +} +.rgthree-info-dialog .rgthree-info-table td svg.logo-civitai { + margin-right: 0.3333em; +} +.rgthree-info-dialog .rgthree-info-table td > *:first-child { + display: block; + padding: 6px 10px; +} +.rgthree-info-dialog .rgthree-info-table td > input, .rgthree-info-dialog .rgthree-info-table td > textarea { + padding: 5px 10px; + border: 0; + box-shadow: inset 1px 1px 5px 0px rgba(0, 0, 0, 0.5); + font: inherit; + appearance: none; + background: #fff; + color: #121212; + resize: vertical; +} +.rgthree-info-dialog .rgthree-info-table td > input:only-child, .rgthree-info-dialog .rgthree-info-table td > textarea:only-child { + width: 100%; +} +:not(#fakeid) .rgthree-info-dialog .rgthree-info-table td .rgthree-button[data-action=fetch-civitai] { + font-size: inherit; + padding: 6px 16px; + margin: 2px; +} +.rgthree-info-dialog .rgthree-info-table tr[data-field-name=userNote] td > span:first-child { + white-space: pre; +} +.rgthree-info-dialog .rgthree-info-table tr.rgthree-info-table-break-row td { + border: 0; + background: transparent; + padding: 12px 4px 4px; + font-size: 1.2em; +} +.rgthree-info-dialog .rgthree-info-table tr.rgthree-info-table-break-row td > small { + font-style: italic; + opacity: 0.66; +} +.rgthree-info-dialog .rgthree-info-table tr.rgthree-info-table-break-row td:empty { + padding: 4px; +} +.rgthree-info-dialog .rgthree-info-table td .-help { + border: 1px solid currentColor; + position: absolute; + right: 5px; + top: 6px; + line-height: 1; + font-size: 11px; + width: 12px; + height: 12px; + border-radius: 8px; + display: flex; + align-content: center; + justify-content: center; + cursor: help; +} +.rgthree-info-dialog .rgthree-info-table td .-help::before { + content: "?"; +} +.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list { + list-style: none; + padding: 2px 8px; + margin: 0; + display: 
flex; + flex-direction: row; + flex-wrap: wrap; + max-height: 15vh; + overflow: auto; +} +.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li { + display: inline-flex; + margin: 2px; + vertical-align: top; + border-radius: 4px; + line-height: 1; + color: rgba(255, 255, 255, 0.85); + background: rgb(73, 91, 106); + font-size: 1.2em; + font-weight: 600; + text-decoration: none; + display: flex; + height: 1.6em; + align-content: center; + justify-content: center; + align-items: center; + box-shadow: inset 0px 0px 0 1px rgba(0, 0, 0, 0.5); + cursor: pointer; + white-space: nowrap; + max-width: 183px; +} +.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li:hover { + background: rgb(68, 109, 142); +} +.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li > svg { + width: auto; + height: 1.2em; +} +.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li > span { + padding-left: 0.5em; + padding-right: 0.5em; + padding-bottom: 0.1em; + text-overflow: ellipsis; + overflow: hidden; +} +.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li > small { + align-self: stretch; + display: flex; + align-items: center; + justify-content: center; + padding: 0 0.5em; + background: rgba(0, 0, 0, 0.2); +} +.rgthree-info-dialog .rgthree-info-table td > ul.rgthree-info-trained-words-list > li.-rgthree-is-selected { + background: rgb(42, 126, 193); +} +.rgthree-info-dialog .rgthree-info-images { + list-style: none; + padding: 0; + margin: 0; + scroll-snap-type: x mandatory; + display: flex; + flex-direction: row; + overflow: auto; +} +.rgthree-info-dialog .rgthree-info-images > li { + scroll-snap-align: start; + max-width: 90%; + flex: 0 0 auto; + display: flex; + align-items: center; + justify-content: center; + flex-direction: column; + overflow: hidden; + padding: 0; + margin: 6px; + font-size: 0; + position: relative; +} +.rgthree-info-dialog .rgthree-info-images > li figure { + margin: 0; + position: static; +} +.rgthree-info-dialog .rgthree-info-images > li figure figcaption { + position: absolute; + left: 0; + width: 100%; + bottom: 0; + padding: 12px; + font-size: 12px; + background: rgba(0, 0, 0, 0.85); + opacity: 0; + transform: translateY(50px); + transition: all 0.25s ease-in-out; +} +.rgthree-info-dialog .rgthree-info-images > li figure figcaption > span { + display: inline-block; + padding: 2px 4px; + margin: 2px; + border-radius: 2px; + border: 1px solid rgba(255, 255, 255, 0.2); + word-break: break-word; +} +.rgthree-info-dialog .rgthree-info-images > li figure figcaption > span label { + display: inline; + padding: 0; + margin: 0; + opacity: 0.5; + pointer-events: none; + user-select: none; +} +.rgthree-info-dialog .rgthree-info-images > li figure figcaption > span a { + color: inherit; + text-decoration: underline; +} +.rgthree-info-dialog .rgthree-info-images > li figure figcaption > span a:hover { + text-decoration: none; +} +.rgthree-info-dialog .rgthree-info-images > li figure figcaption > span a svg { + height: 10px; + margin-left: 4px; + fill: currentColor; +} +.rgthree-info-dialog .rgthree-info-images > li figure figcaption:empty { + text-align: center; +} +.rgthree-info-dialog .rgthree-info-images > li figure figcaption:empty::before { + content: "No data."; +} +.rgthree-info-dialog .rgthree-info-images > li:hover figure figcaption { + opacity: 1; + transform: translateY(0px); +} +.rgthree-info-dialog .rgthree-info-images > li 
.rgthree-info-table { + width: calc(100% - 16px); +} +.rgthree-info-dialog .rgthree-info-civitai-link { + margin: 8px; + color: #eee; +} +.rgthree-info-dialog .rgthree-info-civitai-link a, .rgthree-info-dialog .rgthree-info-civitai-link a:hover, .rgthree-info-dialog .rgthree-info-civitai-link a:visited { + color: inherit; + text-decoration: none; +} +.rgthree-info-dialog .rgthree-info-civitai-link > svg { + width: 16px; + height: 16px; + margin-right: 8px; +} diff --git a/rgthree-comfy/web/common/css/menu.css b/rgthree-comfy/web/common/css/menu.css new file mode 100644 index 0000000000000000000000000000000000000000..97cd66a07aa36fa73b3f1355b582372c7d4c5ec8 --- /dev/null +++ b/rgthree-comfy/web/common/css/menu.css @@ -0,0 +1,91 @@ +.rgthree-menu { + list-style: none; + padding: 0; + margin: 0; + position: fixed; + z-index: 999999; + pointer-events: none; + opacity: 0; + transition: opacity 0.08s ease-in-out; + color: #dde; + background-color: #111; + font-size: 12px; + box-shadow: 0 0 10px black !important; +} +.rgthree-menu > li { + position: relative; + padding: 4px 6px; + z-index: 9999; + white-space: nowrap; +} +.rgthree-menu > li[role=button] { + background-color: var(--comfy-menu-bg) !important; + color: var(--input-text); + cursor: pointer; +} +.rgthree-menu > li[role=button]:hover { + filter: brightness(155%); +} +.rgthree-menu[state^=measuring] { + display: block; + opacity: 0; +} +.rgthree-menu[state=open] { + display: block; + opacity: 1; + pointer-events: all; +} + +.rgthree-top-menu { + box-sizing: border-box; + white-space: nowrap; + background: var(--content-bg); + color: var(--content-fg); + display: flex; + flex-direction: column; +} +.rgthree-top-menu * { + box-sizing: inherit; +} +.rgthree-top-menu menu { + list-style: none; + padding: 0; + margin: 0; +} +.rgthree-top-menu menu > li:not(#fakeid) { + list-style: none; + padding: 0; + margin: 0; +} +.rgthree-top-menu menu > li:not(#fakeid) > button { + cursor: pointer; + padding: 8px 12px 8px 8px; + width: 100%; + text-align: start; + display: flex; + flex-direction: row; + align-items: center; + justify-content: start; +} +.rgthree-top-menu menu > li:not(#fakeid) > button:hover { + background-color: var(--comfy-input-bg); +} +.rgthree-top-menu menu > li:not(#fakeid) > button svg { + height: 16px; + width: auto; + margin-inline-end: 0.6em; +} +.rgthree-top-menu menu > li:not(#fakeid) > button svg.github-star { + fill: rgb(227, 179, 65); +} +.rgthree-top-menu menu > li:not(#fakeid).rgthree-message { + min-height: 32px; +} +.rgthree-top-menu menu > li:not(#fakeid).rgthree-message > span { + padding: 8px 12px; + display: block; + width: 100%; + text-align: center; + font-style: italic; + font-size: 12px; +} diff --git a/rgthree-comfy/web/common/css/pages_base.css b/rgthree-comfy/web/common/css/pages_base.css new file mode 100644 index 0000000000000000000000000000000000000000..49233231999fa37182bb313afe5f9f61389387c6 --- /dev/null +++ b/rgthree-comfy/web/common/css/pages_base.css @@ -0,0 +1,66 @@ +html { + font-size: 100%; + overflow-y: scroll; + -webkit-text-size-adjust: 100%; + -ms-text-size-adjust: 100%; + box-sizing: border-box; +} + +*, *:before, *:after { + box-sizing: inherit; +} + +:root { + --header-height: 56px; + --progress-height: 12px; +} + +button { + all: unset; +} + +.-bevel { + position: relative; +} + +.-bevel::before { + content: ""; + position: absolute; + left: 0; + top: 0; + width: 100%; + height: 100%; + border: 1px solid red; + border-color: rgba(255, 255, 255, 0.15) rgba(255, 255, 255, 0.15) rgba(0, 0, 
0, 0.5) rgba(0, 0, 0, 0.5); + z-index: 5; + pointer-events: none; +} + +body { + background: #202020; + font-family: Arial, sans-serif; + font-size: 1rem; + font-weight: 400; + margin: 0; + padding-top: calc(var(--header-height) + var(--progress-height)); + color: #ffffff; + display: flex; + flex-direction: column; + align-items: center; + justify-content: start; +} + +.app-header { + height: var(--header-height); + padding: 0; + position: fixed; + z-index: 99; + top: 0; + left: 0; + width: 100%; + background: #353535; + display: flex; + flex-direction: row; + align-items: center; + justify-content: start; +} diff --git a/rgthree-comfy/web/common/dialog.js b/rgthree-comfy/web/common/dialog.js new file mode 100644 index 0000000000000000000000000000000000000000..6c1a6eadfb830acb6a76459ea908d93d55849f49 --- /dev/null +++ b/rgthree-comfy/web/common/dialog.js @@ -0,0 +1,109 @@ +import { createElement as $el, getClosestOrSelf, setAttributes } from "./utils_dom.js"; +export class RgthreeDialog extends EventTarget { + constructor(options) { + super(); + this.options = options; + let container = $el("div.rgthree-dialog-container"); + this.element = $el("dialog", { + classes: ["rgthree-dialog", options.class || ""], + child: container, + parent: document.body, + events: { + click: (event) => { + if (!this.element.open || + event.target === container || + getClosestOrSelf(event.target, `.rgthree-dialog-container`) === container) { + return; + } + return this.close(); + }, + }, + }); + this.element.addEventListener("close", (event) => { + this.onDialogElementClose(); + }); + this.titleElement = $el("div.rgthree-dialog-container-title", { + parent: container, + children: !options.title + ? null + : options.title instanceof Element || Array.isArray(options.title) + ? options.title + : typeof options.title === "string" + ? !options.title.includes(" { + var _a; + (_a = button.callback) === null || _a === void 0 ? 
void 0 : _a.call(button, e); + }, + }, + }); + } + if (options.closeButtonLabel !== false) { + $el("button", { + text: options.closeButtonLabel || "Close", + className: "rgthree-button", + parent: footerEl, + events: { + click: (e) => { + this.close(e); + }, + }, + }); + } + } + setTitle(content) { + const title = typeof content !== "string" || content.includes("by rgthree"); + const options = Object.assign({}, opts, { + class: "-iconed -help", + title, + content, + }); + super(options); + } +} diff --git a/rgthree-comfy/web/common/link_fixer.js b/rgthree-comfy/web/common/link_fixer.js new file mode 100644 index 0000000000000000000000000000000000000000..5df0b983e237fd1e28d101927ae7b9401b0420a5 --- /dev/null +++ b/rgthree-comfy/web/common/link_fixer.js @@ -0,0 +1,288 @@ +var IoDirection; +(function (IoDirection) { + IoDirection[IoDirection["INPUT"] = 0] = "INPUT"; + IoDirection[IoDirection["OUTPUT"] = 1] = "OUTPUT"; +})(IoDirection || (IoDirection = {})); +function getNodeById(graph, id) { + if (graph.getNodeById) { + return graph.getNodeById(id); + } + graph = graph; + return graph.nodes.find((n) => n.id === id); +} +function extendLink(link) { + return { + link: link, + id: link[0], + origin_id: link[1], + origin_slot: link[2], + target_id: link[3], + target_slot: link[4], + type: link[5], + }; +} +export function fixBadLinks(graph, fix = false, silent = false, logger = console) { + var _a, _b; + const patchedNodeSlots = {}; + const data = { + patchedNodes: [], + deletedLinks: [], + }; + async function patchNodeSlot(node, ioDir, slot, linkId, op) { + var _a, _b, _c; + patchedNodeSlots[node.id] = patchedNodeSlots[node.id] || {}; + const patchedNode = patchedNodeSlots[node.id]; + if (ioDir == IoDirection.INPUT) { + patchedNode["inputs"] = patchedNode["inputs"] || {}; + if (patchedNode["inputs"][slot] !== undefined) { + !silent && + logger.log(` > Already set ${node.id}.inputs[${slot}] to ${patchedNode["inputs"][slot]} Skipping.`); + return false; + } + let linkIdToSet = op === "REMOVE" ? null : linkId; + patchedNode["inputs"][slot] = linkIdToSet; + if (fix) { + } + } + else { + patchedNode["outputs"] = patchedNode["outputs"] || {}; + patchedNode["outputs"][slot] = patchedNode["outputs"][slot] || { + links: [...(((_b = (_a = node.outputs) === null || _a === void 0 ? void 0 : _a[slot]) === null || _b === void 0 ? void 0 : _b.links) || [])], + changes: {}, + }; + if (patchedNode["outputs"][slot]["changes"][linkId] !== undefined) { + !silent && + logger.log(` > Already set ${node.id}.outputs[${slot}] to ${patchedNode["inputs"][slot]}! Skipping.`); + return false; + } + patchedNode["outputs"][slot]["changes"][linkId] = op; + if (op === "ADD") { + let linkIdIndex = patchedNode["outputs"][slot]["links"].indexOf(linkId); + if (linkIdIndex !== -1) { + !silent && logger.log(` > Hmmm.. asked to add ${linkId} but it is already in list...`); + return false; + } + patchedNode["outputs"][slot]["links"].push(linkId); + if (fix) { + node.outputs = node.outputs || []; + node.outputs[slot] = node.outputs[slot] || {}; + node.outputs[slot].links = node.outputs[slot].links || []; + node.outputs[slot].links.push(linkId); + } + } + else { + let linkIdIndex = patchedNode["outputs"][slot]["links"].indexOf(linkId); + if (linkIdIndex === -1) { + !silent && logger.log(` > Hmmm.. asked to remove ${linkId} but it doesn't exist...`); + return false; + } + patchedNode["outputs"][slot]["links"].splice(linkIdIndex, 1); + if (fix) { + (_c = node.outputs) === null || _c === void 0 ? 
void 0 : _c[slot].links.splice(linkIdIndex, 1); + } + } + } + data.patchedNodes.push(node); + return true; + } + function nodeHasLinkId(node, ioDir, slot, linkId) { + var _a, _b, _c, _d, _e, _f, _g, _h, _j, _k; + let has = false; + if (ioDir === IoDirection.INPUT) { + let nodeHasIt = ((_b = (_a = node.inputs) === null || _a === void 0 ? void 0 : _a[slot]) === null || _b === void 0 ? void 0 : _b.link) === linkId; + if ((_c = patchedNodeSlots[node.id]) === null || _c === void 0 ? void 0 : _c["inputs"]) { + let patchedHasIt = patchedNodeSlots[node.id]["inputs"][slot] === linkId; + if (fix && nodeHasIt !== patchedHasIt) { + throw Error("Error. Expected node to match patched data."); + } + has = patchedHasIt; + } + else { + has = !!nodeHasIt; + } + } + else { + let nodeHasIt = (_f = (_e = (_d = node.outputs) === null || _d === void 0 ? void 0 : _d[slot]) === null || _e === void 0 ? void 0 : _e.links) === null || _f === void 0 ? void 0 : _f.includes(linkId); + if ((_j = (_h = (_g = patchedNodeSlots[node.id]) === null || _g === void 0 ? void 0 : _g["outputs"]) === null || _h === void 0 ? void 0 : _h[slot]) === null || _j === void 0 ? void 0 : _j["changes"][linkId]) { + let patchedHasIt = (_k = patchedNodeSlots[node.id]["outputs"][slot]) === null || _k === void 0 ? void 0 : _k.links.includes(linkId); + if (fix && nodeHasIt !== patchedHasIt) { + throw Error("Error. Expected node to match patched data."); + } + has = !!patchedHasIt; + } + else { + has = !!nodeHasIt; + } + } + return has; + } + function nodeHasAnyLink(node, ioDir, slot) { + var _a, _b, _c, _d, _e, _f, _g, _h, _j, _k; + let hasAny = false; + if (ioDir === IoDirection.INPUT) { + let nodeHasAny = ((_b = (_a = node.inputs) === null || _a === void 0 ? void 0 : _a[slot]) === null || _b === void 0 ? void 0 : _b.link) != null; + if ((_c = patchedNodeSlots[node.id]) === null || _c === void 0 ? void 0 : _c["inputs"]) { + let patchedHasAny = patchedNodeSlots[node.id]["inputs"][slot] != null; + if (fix && nodeHasAny !== patchedHasAny) { + throw Error("Error. Expected node to match patched data."); + } + hasAny = patchedHasAny; + } + else { + hasAny = !!nodeHasAny; + } + } + else { + let nodeHasAny = (_f = (_e = (_d = node.outputs) === null || _d === void 0 ? void 0 : _d[slot]) === null || _e === void 0 ? void 0 : _e.links) === null || _f === void 0 ? void 0 : _f.length; + if ((_j = (_h = (_g = patchedNodeSlots[node.id]) === null || _g === void 0 ? void 0 : _g["outputs"]) === null || _h === void 0 ? void 0 : _h[slot]) === null || _j === void 0 ? void 0 : _j["changes"]) { + let patchedHasAny = (_k = patchedNodeSlots[node.id]["outputs"][slot]) === null || _k === void 0 ? void 0 : _k.links.length; + if (fix && nodeHasAny !== patchedHasAny) { + throw Error("Error. Expected node to match patched data."); + } + hasAny = !!patchedHasAny; + } + else { + hasAny = !!nodeHasAny; + } + } + return hasAny; + } + let links = []; + if (!Array.isArray(graph.links)) { + Object.values(graph.links).reduce((acc, v) => { + acc[v.id] = v; + return acc; + }, links); + } + else { + links = graph.links; + } + const linksReverse = [...links]; + linksReverse.reverse(); + for (let l of linksReverse) { + if (!l) + continue; + const link = l.origin_slot != null ? 
l : extendLink(l); + const originNode = getNodeById(graph, link.origin_id); + const originHasLink = () => nodeHasLinkId(originNode, IoDirection.OUTPUT, link.origin_slot, link.id); + const patchOrigin = (op, id = link.id) => patchNodeSlot(originNode, IoDirection.OUTPUT, link.origin_slot, id, op); + const targetNode = getNodeById(graph, link.target_id); + const targetHasLink = () => nodeHasLinkId(targetNode, IoDirection.INPUT, link.target_slot, link.id); + const targetHasAnyLink = () => nodeHasAnyLink(targetNode, IoDirection.INPUT, link.target_slot); + const patchTarget = (op, id = link.id) => patchNodeSlot(targetNode, IoDirection.INPUT, link.target_slot, id, op); + const originLog = `origin(${link.origin_id}).outputs[${link.origin_slot}].links`; + const targetLog = `target(${link.target_id}).inputs[${link.target_slot}].link`; + if (!originNode || !targetNode) { + if (!originNode && !targetNode) { + !silent && + logger.log(`Link ${link.id} is invalid, ` + + `both origin ${link.origin_id} and target ${link.target_id} do not exist`); + } + else if (!originNode) { + !silent && + logger.log(`Link ${link.id} is funky... ` + + `origin ${link.origin_id} does not exist, but target ${link.target_id} does.`); + if (targetHasLink()) { + !silent && + logger.log(` > [PATCH] ${targetLog} does have link, will remove the inputs' link first.`); + patchTarget("REMOVE", -1); + } + } + else if (!targetNode) { + !silent && + logger.log(`Link ${link.id} is funky... ` + + `target ${link.target_id} does not exist, but origin ${link.origin_id} does.`); + if (originHasLink()) { + !silent && + logger.log(` > [PATCH] Origin's links' has ${link.id}; will remove the link first.`); + patchOrigin("REMOVE"); + } + } + continue; + } + if (targetHasLink() || originHasLink()) { + if (!originHasLink()) { + !silent && + logger.log(`${link.id} is funky... ${originLog} does NOT contain it, but ${targetLog} does.`); + !silent && + logger.log(` > [PATCH] Attempt a fix by adding this ${link.id} to ${originLog}.`); + patchOrigin("ADD"); + } + else if (!targetHasLink()) { + !silent && + logger.log(`${link.id} is funky... ${targetLog} is NOT correct (is ${(_b = (_a = targetNode.inputs) === null || _a === void 0 ? void 0 : _a[link.target_slot]) === null || _b === void 0 ? void 0 : _b.link}), but ${originLog} contains it`); + if (!targetHasAnyLink()) { + !silent && logger.log(` > [PATCH] ${targetLog} is not defined, will set to ${link.id}.`); + let patched = patchTarget("ADD"); + if (!patched) { + !silent && + logger.log(` > [PATCH] Nvm, ${targetLog} already patched. Removing ${link.id} from ${originLog}.`); + patched = patchOrigin("REMOVE"); + } + } + else { + !silent && + logger.log(` > [PATCH] ${targetLog} is defined, removing ${link.id} from ${originLog}.`); + patchOrigin("REMOVE"); + } + } + } + } + for (let l of linksReverse) { + if (!l) + continue; + const link = l.origin_slot != null ? l : extendLink(l); + const originNode = getNodeById(graph, link.origin_id); + const targetNode = getNodeById(graph, link.target_id); + if ((!originNode || !nodeHasLinkId(originNode, IoDirection.OUTPUT, link.origin_slot, link.id)) && + (!targetNode || !nodeHasLinkId(targetNode, IoDirection.INPUT, link.target_slot, link.id))) { + !silent && + logger.log(`${link.id} is def invalid; BOTH origin node ${link.origin_id} ${!originNode ? "is removed" : `doesn\'t have ${link.id}`} and ${link.origin_id} target node ${!targetNode ? 
"is removed" : `doesn\'t have ${link.id}`}.`); + data.deletedLinks.push(link.id); + continue; + } + } + if (fix) { + for (let i = data.deletedLinks.length - 1; i >= 0; i--) { + !silent && logger.log(`Deleting link #${data.deletedLinks[i]}.`); + if (graph.getNodeById) { + delete graph.links[data.deletedLinks[i]]; + } + else { + graph = graph; + const idx = graph.links.findIndex((l) => l && (l[0] === data.deletedLinks[i] || l.id === data.deletedLinks[i])); + if (idx === -1) { + logger.log(`INDEX NOT FOUND for #${data.deletedLinks[i]}`); + } + logger.log(`splicing ${idx} from links`); + graph.links.splice(idx, 1); + } + } + if (!graph.getNodeById) { + graph.links = graph.links.filter((l) => !!l); + } + } + if (!data.patchedNodes.length && !data.deletedLinks.length) { + return { + hasBadLinks: false, + fixed: false, + graph, + patched: data.patchedNodes.length, + deleted: data.deletedLinks.length, + }; + } + !silent && + logger.log(`${fix ? "Made" : "Would make"} ${data.patchedNodes.length || "no"} node link patches, and ${data.deletedLinks.length || "no"} stale link removals.`); + let hasBadLinks = !!(data.patchedNodes.length || data.deletedLinks.length); + if (fix && !silent) { + const rerun = fixBadLinks(graph, false, true); + hasBadLinks = rerun.hasBadLinks; + } + return { + hasBadLinks, + fixed: !!hasBadLinks && fix, + graph, + patched: data.patchedNodes.length, + deleted: data.deletedLinks.length, + }; +} diff --git a/rgthree-comfy/web/common/media/rgthree.svg b/rgthree-comfy/web/common/media/rgthree.svg new file mode 100644 index 0000000000000000000000000000000000000000..85e22fe25d0491149e95e48b1d02f7cd7cc42090 --- /dev/null +++ b/rgthree-comfy/web/common/media/rgthree.svg @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/rgthree-comfy/web/common/media/svgs.js b/rgthree-comfy/web/common/media/svgs.js new file mode 100644 index 0000000000000000000000000000000000000000..ace21b21ae34958d0605ed1eb969eafaa60cabdf --- /dev/null +++ b/rgthree-comfy/web/common/media/svgs.js @@ -0,0 +1,160 @@ +import { createElement as $el } from "../utils_dom.js"; +export const logoRgthree = ``; +export const github = ``; +export const iconStarFilled = ` + + `; +export const iconReplace = ` + + + + + `; +export const iconNode = ` + + + `; +export const iconGear = ` + + `; +export const checkmark = ` + + + + `; +export const logoCivitai = ` + + + + + + + + + + `; +export const iconOutLink = ` + + `; +export const link = ` + + `; +export const pencil = ` + + `; +export const dotdotdot = ` + + + +`; +export const models = ` + + + + +`; +export const pencilColored = ` + + + + + + + + + `; +export const diskColored = ` + + + + + + + +`; +export const folderColored = ` + + +`; +export const modelsColored = ` + + + + +`; +export const legoBlocksColored = ` + + + + + + + + + + + + +`; +export const legoBlockColored = ` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +`; +export const gearColored = ` + + + + +`; +export function $svg(markup, attrs) { + if (!markup.match(/^\s* { + var _a, _b; + const target = getClosestOrSelf(e.target, "[data-callback],menu"); + if (e.which !== 1) { + return; + } + const callback = (_a = target === null || target === void 0 ? void 0 : target.dataset) === null || _a === void 0 ? void 0 : _a['callback']; + if (callback) { + const halt = await ((_b = this.callbacks.get(callback)) === null || _b === void 0 ? 
void 0 : _b(e)); + if (halt !== false) { + this.close(); + } + } + e.preventDefault(); + e.stopPropagation(); + e.stopImmediatePropagation(); + }); + } + setOptions(options) { + for (const option of options) { + if (option.type === 'title') { + this.element.appendChild($el(`li`, { + html: option.label + })); + } + else { + const id = generateId(8); + this.callbacks.set(id, async (e) => { var _a; return (_a = option === null || option === void 0 ? void 0 : option.callback) === null || _a === void 0 ? void 0 : _a.call(option, e); }); + this.element.appendChild($el(`li[role="button"][data-callback="${id}"]`, { + html: option.label + })); + } + } + } + toElement() { + return this.element; + } + async open(e) { + const parent = e.target.closest('div,dialog,body'); + parent.appendChild(this.element); + setAttributes(this.element, { + style: { + left: `${e.clientX + 16}px`, + top: `${e.clientY - 16}px`, + } + }); + this.element.setAttribute('state', 'measuring-open'); + await wait(16); + const rect = this.element.getBoundingClientRect(); + if (rect.right > window.innerWidth) { + this.element.style.left = `${e.clientX - rect.width - 16}px`; + await wait(16); + } + this.element.setAttribute('state', 'open'); + setTimeout(() => { + window.addEventListener('pointerdown', this.handleWindowPointerDownBound); + }); + } + handleWindowPointerDown(e) { + if (!this.element.contains(e.target)) { + this.close(); + } + } + async close() { + window.removeEventListener('pointerdown', this.handleWindowPointerDownBound); + this.element.setAttribute('state', 'measuring-closed'); + await wait(16); + this.element.setAttribute('state', 'closed'); + this.element.remove(); + } + isOpen() { + return (this.element.getAttribute('state') || '').includes('open'); + } +} +export class MenuButton { + constructor(options) { + this.element = $el('button.rgthree-button[data-action="open-menu"]'); + this.options = options; + this.element.innerHTML = options.icon; + this.menu = new Menu(options.options); + this.element.addEventListener('pointerdown', (e) => { + if (!this.menu.isOpen()) { + this.menu.open(e); + } + }); + } + toElement() { + return this.element; + } +} diff --git a/rgthree-comfy/web/common/model_info_service.js b/rgthree-comfy/web/common/model_info_service.js new file mode 100644 index 0000000000000000000000000000000000000000..b2f2b9ea96e4ca1fd377c5cd895753187d8cab9e --- /dev/null +++ b/rgthree-comfy/web/common/model_info_service.js @@ -0,0 +1,53 @@ +import { rgthreeApi } from "./rgthree_api.js"; +import { api } from "../../scripts/api.js"; +class ModelInfoService extends EventTarget { + constructor() { + super(); + this.loraToInfo = new Map(); + api.addEventListener("rgthree-refreshed-lora-info", this.handleLoraAsyncUpdate.bind(this)); + } + setFreshLoraData(file, info) { + this.loraToInfo.set(file, info); + this.dispatchEvent(new CustomEvent("rgthree-model-service-lora-details", { detail: { lora: info } })); + } + async getLora(file, refresh = false, light = false) { + if (this.loraToInfo.has(file) && !refresh) { + return this.loraToInfo.get(file); + } + return this.fetchLora(file, refresh, light); + } + async fetchLora(file, refresh = false, light = false) { + let info = null; + if (!refresh) { + info = await rgthreeApi.getLorasInfo(file, light); + } + else { + info = await rgthreeApi.refreshLorasInfo(file); + } + if (!light) { + this.loraToInfo.set(file, info); + } + return info; + } + async refreshLora(file) { + return this.fetchLora(file, true); + } + async clearLoraFetchedData(file) { + await 
rgthreeApi.clearLorasInfo(file); + this.loraToInfo.delete(file); + return null; + } + async saveLoraPartial(file, data) { + let info = await rgthreeApi.saveLoraInfo(file, data); + this.loraToInfo.set(file, info); + return info; + } + handleLoraAsyncUpdate(event) { + var _a; + const info = (_a = event.detail) === null || _a === void 0 ? void 0 : _a.data; + if (info === null || info === void 0 ? void 0 : info.file) { + this.setFreshLoraData(info.file, info); + } + } +} +export const SERVICE = new ModelInfoService(); diff --git a/rgthree-comfy/web/common/progress_bar.js b/rgthree-comfy/web/common/progress_bar.js new file mode 100644 index 0000000000000000000000000000000000000000..e342f0e926f46af0d8b817e01b522d0bc0161260 --- /dev/null +++ b/rgthree-comfy/web/common/progress_bar.js @@ -0,0 +1,179 @@ +import { SERVICE as PROMPT_SERVICE } from "../common/prompt_service.js"; +import { createElement } from "./utils_dom.js"; +export class RgthreeProgressBar extends HTMLElement { + static create() { + return document.createElement(RgthreeProgressBar.NAME); + } + get currentNodeId() { + var _a, _b; + const prompt = this.currentPromptExecution; + const nodeId = ((_a = prompt === null || prompt === void 0 ? void 0 : prompt.errorDetails) === null || _a === void 0 ? void 0 : _a.node_id) || ((_b = prompt === null || prompt === void 0 ? void 0 : prompt.currentlyExecuting) === null || _b === void 0 ? void 0 : _b.nodeId); + return nodeId || null; + } + constructor() { + super(); + this.shadow = null; + this.currentPromptExecution = null; + this.onProgressUpdateBound = this.onProgressUpdate.bind(this); + this.connected = false; + } + onProgressUpdate(e) { + var _a, _b, _c, _d; + if (!this.connected) + return; + const prompt = e.detail.prompt; + this.currentPromptExecution = prompt; + if (prompt === null || prompt === void 0 ? void 0 : prompt.errorDetails) { + let progressText = `${(_a = prompt.errorDetails) === null || _a === void 0 ? void 0 : _a.exception_type} ${((_b = prompt.errorDetails) === null || _b === void 0 ? void 0 : _b.node_id) || ""} ${((_c = prompt.errorDetails) === null || _c === void 0 ? void 0 : _c.node_type) || ""}`; + this.progressTextEl.innerText = progressText; + this.progressNodesEl.classList.add("-error"); + this.progressStepsEl.classList.add("-error"); + return; + } + if (prompt === null || prompt === void 0 ? void 0 : prompt.currentlyExecuting) { + this.progressNodesEl.classList.remove("-error"); + this.progressStepsEl.classList.remove("-error"); + const current = prompt === null || prompt === void 0 ? void 0 : prompt.currentlyExecuting; + let progressText = `(${e.detail.queue}) `; + if (!prompt.totalNodes) { + progressText += `??%`; + this.progressNodesEl.style.width = `0%`; + } + else { + const percent = (prompt.executedNodeIds.length / prompt.totalNodes) * 100; + this.progressNodesEl.style.width = `${Math.max(2, percent)}%`; + progressText += `${Math.round(percent)}%`; + } + let nodeLabel = (_d = current.nodeLabel) === null || _d === void 0 ? 
void 0 : _d.trim(); + let stepsLabel = ""; + if (current.step != null && current.maxSteps) { + const percent = (current.step / current.maxSteps) * 100; + this.progressStepsEl.style.width = `${percent}%`; + if (current.pass > 1 || current.maxPasses != null) { + stepsLabel += `#${current.pass}`; + if (current.maxPasses && current.maxPasses > 0) { + stepsLabel += `/${current.maxPasses}`; + } + stepsLabel += ` - `; + } + stepsLabel += `${Math.round(percent)}%`; + } + if (nodeLabel || stepsLabel) { + progressText += ` - ${nodeLabel || "???"}${stepsLabel ? ` (${stepsLabel})` : ""}`; + } + if (!stepsLabel) { + this.progressStepsEl.style.width = `0%`; + } + this.progressTextEl.innerText = progressText; + } + else { + if (e === null || e === void 0 ? void 0 : e.detail.queue) { + this.progressTextEl.innerText = `(${e.detail.queue}) Running... in another tab`; + } + else { + this.progressTextEl.innerText = "Idle"; + } + this.progressNodesEl.style.width = `0%`; + this.progressStepsEl.style.width = `0%`; + } + } + connectedCallback() { + if (!this.connected) { + PROMPT_SERVICE.addEventListener("progress-update", this.onProgressUpdateBound); + this.connected = true; + } + if (this.shadow) { + this.progressTextEl.innerText = "Idle"; + this.progressNodesEl.style.width = `0%`; + this.progressStepsEl.style.width = `0%`; + return; + } + this.shadow = this.attachShadow({ mode: "open" }); + const sheet = new CSSStyleSheet(); + sheet.replaceSync(` + + :host { + position: relative; + overflow: hidden; + box-sizing: border-box; + background: var(--rgthree-progress-bg-color); + --rgthree-progress-bg-color: rgba(23, 23, 23, 0.9); + --rgthree-progress-nodes-bg-color: rgb(0, 128, 0); + --rgthree-progress-steps-bg-color: rgb(0, 128, 0); + --rgthree-progress-error-bg-color: rgb(128, 0, 0); + --rgthree-progress-text-color: #fff; + } + :host * { + box-sizing: inherit; + } + + :host > div.bar { + background: var(--rgthree-progress-nodes-bg-color); + position: absolute; + left: 0; + top: 0; + width: 0%; + height: 50%; + z-index: 1; + transition: width 50ms ease-in-out; + } + :host > div.bar + div.bar { + background: var(--rgthree-progress-steps-bg-color); + top: 50%; + height: 50%; + z-index: 2; + } + :host > div.bar.-error { + background: var(--rgthree-progress-error-bg-color); + } + + :host > .overlay { + position: absolute; + left: 0; + top: 0; + width: 100%; + height: 100%; + z-index: 5; + background: linear-gradient(to bottom, rgba(255,255,255,0.25), rgba(0,0,0,0.25)); + mix-blend-mode: overlay; + } + + :host > span { + position: relative; + z-index: 4; + text-align: left; + font-size: inherit; + height: 100%; + font-family: sans-serif; + text-shadow: 1px 1px 0px #000; + display: flex; + flex-direction: row; + padding: 0 6px; + align-items: center; + justify-content: start; + color: var(--rgthree-progress-text-color); + text-shadow: black 0px 0px 2px; + } + + :host > div.bar[style*="width: 0%"]:first-child, + :host > div.bar[style*="width:0%"]:first-child { + height: 0%; + } + :host > div.bar[style*="width: 0%"]:first-child + div, + :host > div.bar[style*="width:0%"]:first-child + div { + bottom: 0%; + } + `); + this.shadow.adoptedStyleSheets = [sheet]; + const overlayEl = createElement(`div.overlay[part="overlay"]`, { parent: this.shadow }); + this.progressNodesEl = createElement(`div.bar[part="progress-nodes"]`, { parent: this.shadow }); + this.progressStepsEl = createElement(`div.bar[part="progress-steps"]`, { parent: this.shadow }); + this.progressTextEl = createElement(`span[part="text"]`, { text: "Idle", parent: 
this.shadow }); + } + disconnectedCallback() { + this.connected = false; + PROMPT_SERVICE.removeEventListener("progress-update", this.onProgressUpdateBound); + } +} +RgthreeProgressBar.NAME = "rgthree-progress-bar"; +customElements.define(RgthreeProgressBar.NAME, RgthreeProgressBar); diff --git a/rgthree-comfy/web/common/prompt_service.js b/rgthree-comfy/web/common/prompt_service.js new file mode 100644 index 0000000000000000000000000000000000000000..720802925aeb20d8d6cac29319b36690f0965f56 --- /dev/null +++ b/rgthree-comfy/web/common/prompt_service.js @@ -0,0 +1,188 @@ +import { api } from "../../scripts/api.js"; +import { getResolver } from "./shared_utils.js"; +export class PromptExecution { + constructor(id) { + this.promptApi = null; + this.executedNodeIds = []; + this.totalNodes = 0; + this.currentlyExecuting = null; + this.errorDetails = null; + this.apiPrompt = getResolver(); + this.id = id; + } + setPrompt(prompt) { + this.promptApi = prompt.output; + this.totalNodes = Object.keys(this.promptApi).length; + this.apiPrompt.resolve(null); + } + getApiNode(nodeId) { + var _a; + return ((_a = this.promptApi) === null || _a === void 0 ? void 0 : _a[String(nodeId)]) || null; + } + getNodeLabel(nodeId) { + var _a, _b; + const apiNode = this.getApiNode(nodeId); + let label = ((_a = apiNode === null || apiNode === void 0 ? void 0 : apiNode._meta) === null || _a === void 0 ? void 0 : _a.title) || (apiNode === null || apiNode === void 0 ? void 0 : apiNode.class_type) || undefined; + if (!label) { + const graphNode = (_b = this.maybeGetComfyGraph()) === null || _b === void 0 ? void 0 : _b.getNodeById(Number(nodeId)); + label = (graphNode === null || graphNode === void 0 ? void 0 : graphNode.title) || (graphNode === null || graphNode === void 0 ? void 0 : graphNode.type) || undefined; + } + return label; + } + executing(nodeId, step, maxSteps) { + var _a; + if (nodeId == null) { + this.currentlyExecuting = null; + return; + } + if (((_a = this.currentlyExecuting) === null || _a === void 0 ? void 0 : _a.nodeId) !== nodeId) { + if (this.currentlyExecuting != null) { + this.executedNodeIds.push(nodeId); + } + this.currentlyExecuting = { nodeId, nodeLabel: this.getNodeLabel(nodeId), pass: 0 }; + this.apiPrompt.promise.then(() => { + var _a; + if (this.currentlyExecuting == null) { + return; + } + const apiNode = this.getApiNode(nodeId); + if (!this.currentlyExecuting.nodeLabel) { + this.currentlyExecuting.nodeLabel = this.getNodeLabel(nodeId); + } + if ((apiNode === null || apiNode === void 0 ? void 0 : apiNode.class_type) === "UltimateSDUpscale") { + this.currentlyExecuting.pass--; + this.currentlyExecuting.maxPasses = -1; + } + else if ((apiNode === null || apiNode === void 0 ? void 0 : apiNode.class_type) === "IterativeImageUpscale") { + this.currentlyExecuting.maxPasses = (_a = apiNode === null || apiNode === void 0 ? void 0 : apiNode.inputs["steps"]) !== null && _a !== void 0 ? _a : -1; + } + }); + } + if (step != null) { + if (!this.currentlyExecuting.step || step < this.currentlyExecuting.step) { + this.currentlyExecuting.pass++; + } + this.currentlyExecuting.step = step; + this.currentlyExecuting.maxSteps = maxSteps; + } + } + error(details) { + this.errorDetails = details; + } + maybeGetComfyGraph() { + var _a; + return ((_a = window === null || window === void 0 ? void 0 : window.app) === null || _a === void 0 ? 
void 0 : _a.graph) || null; + } +} +class PromptService extends EventTarget { + constructor(api) { + super(); + this.promptsMap = new Map(); + this.currentExecution = null; + this.lastQueueRemaining = 0; + const that = this; + const queuePrompt = api.queuePrompt; + api.queuePrompt = async function (num, prompt) { + let response; + try { + response = await queuePrompt.apply(api, [...arguments]); + } + catch (e) { + const promptExecution = that.getOrMakePrompt("error"); + promptExecution.error({ exception_type: "Unknown." }); + throw e; + } + const promptExecution = that.getOrMakePrompt(response.prompt_id); + promptExecution.setPrompt(prompt); + if (!that.currentExecution) { + that.currentExecution = promptExecution; + } + that.promptsMap.set(response.prompt_id, promptExecution); + that.dispatchEvent(new CustomEvent("queue-prompt", { + detail: { + prompt: promptExecution, + }, + })); + return response; + }; + api.addEventListener("status", (e) => { + var _a; + if (!((_a = e.detail) === null || _a === void 0 ? void 0 : _a.exec_info)) + return; + this.lastQueueRemaining = e.detail.exec_info.queue_remaining; + this.dispatchProgressUpdate(); + }); + api.addEventListener("execution_start", (e) => { + if (!this.promptsMap.has(e.detail.prompt_id)) { + console.warn("'execution_start' fired before prompt was made."); + } + const prompt = this.getOrMakePrompt(e.detail.prompt_id); + this.currentExecution = prompt; + this.dispatchProgressUpdate(); + }); + api.addEventListener("executing", (e) => { + if (!this.currentExecution) { + this.currentExecution = this.getOrMakePrompt("unknown"); + console.warn("'executing' fired before prompt was made."); + } + this.currentExecution.executing(e.detail); + this.dispatchProgressUpdate(); + if (e.detail == null) { + this.currentExecution = null; + } + }); + api.addEventListener("progress", (e) => { + if (!this.currentExecution) { + this.currentExecution = this.getOrMakePrompt(e.detail.prompt_id); + console.warn("'progress' fired before prompt was made."); + } + this.currentExecution.executing(e.detail.node, e.detail.value, e.detail.max); + this.dispatchProgressUpdate(); + }); + api.addEventListener("execution_cached", (e) => { + if (!this.currentExecution) { + this.currentExecution = this.getOrMakePrompt(e.detail.prompt_id); + console.warn("'execution_cached' fired before prompt was made."); + } + for (const cached of e.detail.nodes) { + this.currentExecution.executing(cached); + } + this.dispatchProgressUpdate(); + }); + api.addEventListener("executed", (e) => { + if (!this.currentExecution) { + this.currentExecution = this.getOrMakePrompt(e.detail.prompt_id); + console.warn("'executed' fired before prompt was made."); + } + }); + api.addEventListener("execution_error", (e) => { + var _a; + if (!this.currentExecution) { + this.currentExecution = this.getOrMakePrompt(e.detail.prompt_id); + console.warn("'execution_error' fired before prompt was made."); + } + (_a = this.currentExecution) === null || _a === void 0 ? 
void 0 : _a.error(e.detail); + this.dispatchProgressUpdate(); + }); + } + async queuePrompt(prompt) { + return await api.queuePrompt(-1, prompt); + } + dispatchProgressUpdate() { + this.dispatchEvent(new CustomEvent("progress-update", { + detail: { + queue: this.lastQueueRemaining, + prompt: this.currentExecution, + }, + })); + } + getOrMakePrompt(id) { + let prompt = this.promptsMap.get(id); + if (!prompt) { + prompt = new PromptExecution(id); + this.promptsMap.set(id, prompt); + } + return prompt; + } +} +export const SERVICE = new PromptService(api); diff --git a/rgthree-comfy/web/common/rgthree_api.js b/rgthree-comfy/web/common/rgthree_api.js new file mode 100644 index 0000000000000000000000000000000000000000..5e59a1b4754ae8714aae9a915bc408c3a53237c7 --- /dev/null +++ b/rgthree-comfy/web/common/rgthree_api.js @@ -0,0 +1,64 @@ +class RgthreeApi { + constructor(baseUrl) { + this.getCheckpointsPromise = null; + this.getSamplersPromise = null; + this.getSchedulersPromise = null; + this.getLorasPromise = null; + this.getWorkflowsPromise = null; + this.baseUrl = baseUrl || "./rgthree/api"; + } + apiURL(route) { + return `${this.baseUrl}${route}`; + } + fetchApi(route, options) { + return fetch(this.apiURL(route), options); + } + async fetchJson(route, options) { + const r = await this.fetchApi(route, options); + return await r.json(); + } + async postJson(route, json) { + const body = new FormData(); + body.append("json", JSON.stringify(json)); + return await rgthreeApi.fetchJson(route, { method: "POST", body }); + } + getLoras(force = false) { + if (!this.getLorasPromise || force) { + this.getLorasPromise = this.fetchJson("/loras", { cache: "no-store" }); + } + return this.getLorasPromise; + } + async fetchApiJsonOrNull(route, options) { + const response = await this.fetchJson(route, options); + if (response.status === 200 && response.data) { + return response.data || null; + } + return null; + } + async getLorasInfo(...args) { + const params = new URLSearchParams(); + const isSingleLora = typeof args[0] == 'string'; + if (isSingleLora) { + params.set("file", args[0]); + } + params.set("light", (isSingleLora ? args[1] : args[0]) === false ? '0' : '1'); + const path = `/loras/info?` + params.toString(); + return await this.fetchApiJsonOrNull(path); + } + async refreshLorasInfo(file) { + const path = `/loras/info/refresh` + (file ? `?file=${encodeURIComponent(file)}` : ''); + const infos = await this.fetchApiJsonOrNull(path); + return infos; + } + async clearLorasInfo(file) { + const path = `/loras/info/clear` + (file ? 
`?file=${encodeURIComponent(file)}` : ''); + await this.fetchApiJsonOrNull(path); + return; + } + async saveLoraInfo(lora, data) { + const body = new FormData(); + body.append("json", JSON.stringify(data)); + return await this.fetchApiJsonOrNull(`/loras/info?file=${encodeURIComponent(lora)}`, { cache: "no-store", method: "POST", body }); + } +} +export const rgthreeApi = new RgthreeApi(); diff --git a/rgthree-comfy/web/common/shared_utils.js b/rgthree-comfy/web/common/shared_utils.js new file mode 100644 index 0000000000000000000000000000000000000000..ee24d0a88a5c7f00908e75ee4e5229ecdc928aac --- /dev/null +++ b/rgthree-comfy/web/common/shared_utils.js @@ -0,0 +1,115 @@ +export function getResolver(timeout = 5000) { + const resolver = {}; + resolver.id = generateId(8); + resolver.completed = false; + resolver.resolved = false; + resolver.rejected = false; + resolver.promise = new Promise((resolve, reject) => { + resolver.reject = () => { + resolver.completed = true; + resolver.rejected = true; + reject(); + }; + resolver.resolve = (data) => { + resolver.completed = true; + resolver.resolved = true; + resolve(data); + }; + }); + resolver.timeout = setTimeout(() => { + if (!resolver.completed) { + resolver.reject(); + } + }, timeout); + return resolver; +} +const DEBOUNCE_FN_TO_PROMISE = new WeakMap(); +export function debounce(fn, ms = 64) { + if (!DEBOUNCE_FN_TO_PROMISE.get(fn)) { + DEBOUNCE_FN_TO_PROMISE.set(fn, wait(ms).then(() => { + DEBOUNCE_FN_TO_PROMISE.delete(fn); + fn(); + })); + } + return DEBOUNCE_FN_TO_PROMISE.get(fn); +} +export function wait(ms = 16) { + if (ms === 16) { + return new Promise((resolve) => { + requestAnimationFrame(() => { + resolve(); + }); + }); + } + return new Promise((resolve) => { + setTimeout(() => { + resolve(); + }, ms); + }); +} +function dec2hex(dec) { + return dec.toString(16).padStart(2, "0"); +} +export function generateId(length) { + const arr = new Uint8Array(length / 2); + crypto.getRandomValues(arr); + return Array.from(arr, dec2hex).join(""); +} +export function getObjectValue(obj, objKey, def) { + if (!obj || !objKey) + return def; + const keys = objKey.split("."); + const key = keys.shift(); + const found = obj[key]; + if (keys.length) { + return getObjectValue(found, keys.join("."), def); + } + return found; +} +export function setObjectValue(obj, objKey, value, createMissingObjects = true) { + if (!obj || !objKey) + return obj; + const keys = objKey.split("."); + const key = keys.shift(); + if (obj[key] === undefined) { + if (!createMissingObjects) { + return; + } + obj[key] = {}; + } + if (!keys.length) { + obj[key] = value; + } + else { + if (typeof obj[key] != "object") { + obj[key] = {}; + } + setObjectValue(obj[key], keys.join("."), value, createMissingObjects); + } + return obj; +} +export function moveArrayItem(arr, itemOrFrom, to) { + const from = typeof itemOrFrom === "number" ? itemOrFrom : arr.indexOf(itemOrFrom); + arr.splice(to, 0, arr.splice(from, 1)[0]); +} +export function removeArrayItem(arr, itemOrIndex) { + const index = typeof itemOrIndex === "number" ? 
itemOrIndex : arr.indexOf(itemOrIndex); + arr.splice(index, 1); +} +export function injectCss(href) { + if (document.querySelector(`link[href^="${href}"]`)) { + return Promise.resolve(); + } + return new Promise((resolve) => { + const link = document.createElement("link"); + link.setAttribute("rel", "stylesheet"); + link.setAttribute("type", "text/css"); + const timeout = setTimeout(resolve, 1000); + link.addEventListener("load", (e) => { + clearInterval(timeout); + resolve(); + }); + link.href = href; + document.head.appendChild(link); + }); +} diff --git a/rgthree-comfy/web/common/utils_dom.js b/rgthree-comfy/web/common/utils_dom.js new file mode 100644 index 0000000000000000000000000000000000000000..a41c0ed923df22ecffe88002dc0b787a33e0d78c --- /dev/null +++ b/rgthree-comfy/web/common/utils_dom.js @@ -0,0 +1,306 @@ +const DIRECT_ATTRIBUTE_MAP = { + cellpadding: 'cellPadding', + cellspacing: 'cellSpacing', + colspan: 'colSpan', + frameborder: 'frameBorder', + height: 'height', + maxlength: 'maxLength', + nonce: 'nonce', + role: 'role', + rowspan: 'rowSpan', + type: 'type', + usemap: 'useMap', + valign: 'vAlign', + width: 'width', +}; +const RGX_NUMERIC_STYLE_UNIT = 'px'; +const RGX_NUMERIC_STYLE = /^((max|min)?(width|height)|margin|padding|(margin|padding)?(left|top|bottom|right)|fontsize|borderwidth)$/i; +const RGX_DEFAULT_VALUE_PROP = /input|textarea|select/i; +function localAssertNotFalsy(input, errorMsg = `Input is not of type.`) { + if (input == null) { + throw new Error(errorMsg); + } + return input; +} +const RGX_STRING_VALID = '[a-z0-9_-]'; +const RGX_TAG = new RegExp(`^([a-z]${RGX_STRING_VALID}*)(\\.|\\[|\\#|$)`, 'i'); +const RGX_ATTR_ID = new RegExp(`#(${RGX_STRING_VALID}+)`, 'gi'); +const RGX_ATTR_CLASS = new RegExp(`(^|\\S)\\.([a-z0-9_\\-\\.]+)`, 'gi'); +const RGX_STRING_CONTENT_TO_SQUARES = '(.*?)(\\[|\\])'; +const RGX_ATTRS_MAYBE_OPEN = new RegExp(`\\[${RGX_STRING_CONTENT_TO_SQUARES}`, 'gi'); +const RGX_ATTRS_FOLLOW_OPEN = new RegExp(`^${RGX_STRING_CONTENT_TO_SQUARES}`, 'gi'); +export function query(selectors, parent = document) { + return Array.from(parent.querySelectorAll(selectors)).filter(n => !!n); +} +export function queryOne(selectors, parent = document) { + var _a; + return (_a = parent.querySelector(selectors)) !== null && _a !== void 0 ? _a : null; +} +export function createText(text) { + return document.createTextNode(text); +} +export function getClosestOrSelf(element, query) { + const el = element; + return ((el === null || el === void 0 ? void 0 : el.closest) && (el.matches(query) && el || el.closest(query))) || null; +} +export function createElement(selectorOrMarkup, attrs) { + const frag = getHtmlFragment(selectorOrMarkup); + let element = frag === null || frag === void 0 ? 
void 0 : frag.firstElementChild; + let selector = ""; + if (!element) { + selector = selectorOrMarkup.replace(/[\r\n]\s*/g, ""); + const tag = getSelectorTag(selector) || "div"; + element = document.createElement(tag); + selector = selector.replace(RGX_TAG, "$2"); + selector = selector.replace(RGX_ATTR_ID, '[id="$1"]'); + selector = selector.replace(RGX_ATTR_CLASS, (match, p1, p2) => `${p1}[class="${p2.replace(/\./g, " ")}"]`); + } + const selectorAttrs = getSelectorAttributes(selector); + if (selectorAttrs) { + for (const attr of selectorAttrs) { + let matches = attr.substring(1, attr.length - 1).split("="); + let key = localAssertNotFalsy(matches.shift()); + let value = matches.join("="); + if (value === undefined) { + setAttribute(element, key, true); + } + else { + value = value.replace(/^['"](.*)['"]$/, "$1"); + setAttribute(element, key, value); + } + } + } + if (attrs) { + setAttributes(element, attrs); + } + return element; +} +function getSelectorTag(str) { + return tryMatch(str, RGX_TAG); +} +function getSelectorAttributes(selector) { + RGX_ATTRS_MAYBE_OPEN.lastIndex = 0; + let attrs = []; + let result; + while (result = RGX_ATTRS_MAYBE_OPEN.exec(selector)) { + let attr = result[0]; + if (attr.endsWith(']')) { + attrs.push(attr); + } + else { + attr = result[0] + + getOpenAttributesRecursive(selector.substr(RGX_ATTRS_MAYBE_OPEN.lastIndex), 2); + RGX_ATTRS_MAYBE_OPEN.lastIndex += (attr.length - result[0].length); + attrs.push(attr); + } + } + return attrs; +} +function getOpenAttributesRecursive(selectorSubstring, openCount) { + let matches = selectorSubstring.match(RGX_ATTRS_FOLLOW_OPEN); + let result = ''; + if (matches && matches.length) { + result = matches[0]; + openCount += result.endsWith(']') ? -1 : 1; + if (openCount > 0) { + result += getOpenAttributesRecursive(selectorSubstring.substr(result.length), openCount); + } + } + return result; +} +function tryMatch(str, rgx, index = 1) { + var _a; + let found = ''; + try { + found = ((_a = str.match(rgx)) === null || _a === void 0 ? void 0 : _a[index]) || ''; + } + catch (e) { + found = ''; + } + return found; +} +export function setAttributes(element, data) { + let attr; + for (attr in data) { + if (data.hasOwnProperty(attr)) { + setAttribute(element, attr, data[attr]); + } + } +} +function getHtmlFragment(value) { + if (value.match(/^\s*<.*?>[\s\S]*<\/[a-z0-9]+>\s*$/)) { + return document.createRange().createContextualFragment(value.trim()); + } + return null; +} +function getChild(value) { + if (value instanceof Node) { + return value; + } + if (typeof value === 'string') { + let child = getHtmlFragment(value); + if (child) { + return child; + } + if (getSelectorTag(value)) { + return createElement(value); + } + return createText(value); + } + if (value && typeof value.toElement === 'function') { + return value.toElement(); + } + return null; +} +export function setAttribute(element, attribute, value) { + let isRemoving = value == null; + if (attribute === 'default') { + attribute = RGX_DEFAULT_VALUE_PROP.test(element.nodeName) ? 'value' : 'text'; + } + if (attribute === 'text') { + empty(element).appendChild(createText(value != null ? String(value) : '')); + } + else if (attribute === 'html') { + empty(element).innerHTML += value != null ? String(value) : ''; + } + else if (attribute == 'style') { + if (typeof value === 'string') { + element.style.cssText = isRemoving ? '' : (value != null ? 
String(value) : ''); + } + else { + for (const [styleKey, styleValue] of Object.entries(value)) { + element.style[styleKey] = styleValue; + } + } + } + else if (attribute == 'events') { + for (const [key, fn] of Object.entries(value)) { + addEvent(element, key, fn); + } + } + else if (attribute === 'parent') { + value.appendChild(element); + } + else if (attribute === 'child' || attribute === 'children') { + if (typeof value === 'string' && /^\[[^\[\]]+\]$/.test(value)) { + const parseable = value.replace(/^\[([^\[\]]+)\]$/, '["$1"]').replace(/,/g, '","'); + try { + const parsed = JSON.parse(parseable); + value = parsed; + } + catch (e) { + console.error(e); + } + } + if (attribute === 'children') { + empty(element); + } + let children = value instanceof Array ? value : [value]; + for (let child of children) { + child = getChild(child); + if (child instanceof Node) { + if (element instanceof HTMLTemplateElement) { + element.content.appendChild(child); + } + else { + element.appendChild(child); + } + } + } + } + else if (attribute == 'for') { + element.htmlFor = value != null ? String(value) : ''; + if (isRemoving) { + element.removeAttribute('for'); + } + } + else if (attribute === 'class' || attribute === 'className' || attribute === 'classes') { + element.className = isRemoving ? '' : Array.isArray(value) ? value.join(' ') : String(value); + } + else if (attribute === 'dataset') { + if (typeof value !== 'object') { + console.error('Expecting an object for dataset'); + return; + } + for (const [key, val] of Object.entries(value)) { + element.dataset[key] = String(val); + } + } + else if (attribute == 'onclick' && typeof value === 'function') { + element.addEventListener('click', value); + } + else if (['checked', 'disabled', 'readonly', 'required', 'selected'].includes(attribute)) { + element[attribute] = !!value; + if (!value) { + element.removeAttribute(attribute); + } + else { + element.setAttribute(attribute, attribute); + } + } + else if (DIRECT_ATTRIBUTE_MAP.hasOwnProperty(attribute)) { + if (isRemoving) { + element.removeAttribute(DIRECT_ATTRIBUTE_MAP[attribute]); + } + else { + element.setAttribute(DIRECT_ATTRIBUTE_MAP[attribute], String(value)); + } + } + else if (isRemoving) { + element.removeAttribute(attribute); + } + else { + let oldVal = element.getAttribute(attribute); + if (oldVal !== value) { + element.setAttribute(attribute, String(value)); + } + } +} +function addEvent(element, key, fn) { + element.addEventListener(key, fn); +} +function setStyles(element, styles = null) { + if (styles) { + for (let name in styles) { + setStyle(element, name, styles[name]); + } + } + return element; +} +function setStyle(element, name, value) { + name = (name.indexOf('float') > -1 ? 'cssFloat' : name); + if (name.indexOf('-') != -1) { + name = name.replace(/-\D/g, (match) => { + return match.charAt(1).toUpperCase(); + }); + } + if (value == String(Number(value)) && RGX_NUMERIC_STYLE.test(name)) { + value = value + RGX_NUMERIC_STYLE_UNIT; + } + if (name === 'display' && typeof value !== 'string') { + value = !!value ? null : 'none'; + } + element.style[name] = value === null ? null : String(value); + return element; +} +; +export function empty(element) { + while (element.firstChild) { + element.removeChild(element.firstChild); + } + return element; +} +export function appendChildren(el, children) { + children = !Array.isArray(children) ? 
[children] : children; + for (let child of children) { + child = getChild(child); + if (child instanceof Node) { + if (el instanceof HTMLTemplateElement) { + el.content.appendChild(child); + } + else { + el.appendChild(child); + } + } + } +} diff --git a/rgthree-comfy/web/common/utils_workflow.js b/rgthree-comfy/web/common/utils_workflow.js new file mode 100644 index 0000000000000000000000000000000000000000..b8c61cfead0aeee76115a1b60e83a69edf271e7c --- /dev/null +++ b/rgthree-comfy/web/common/utils_workflow.js @@ -0,0 +1,55 @@ +import { getResolver } from "./shared_utils.js"; +import { getPngMetadata, getWebpMetadata } from "../../scripts/pnginfo.js"; +function parseWorkflowJson(stringJson) { + stringJson = stringJson || "null"; + stringJson = stringJson.replace(/:\s*NaN/g, ": null"); + return JSON.parse(stringJson); +} +export async function tryToGetWorkflowDataFromEvent(e) { + var _a, _b, _c, _d; + let work; + for (const file of ((_a = e.dataTransfer) === null || _a === void 0 ? void 0 : _a.files) || []) { + const data = await tryToGetWorkflowDataFromFile(file); + if (data.workflow || data.prompt) { + return data; + } + } + const validTypes = ["text/uri-list", "text/x-moz-url"]; + const match = (((_b = e.dataTransfer) === null || _b === void 0 ? void 0 : _b.types) || []).find((t) => validTypes.find((v) => t === v)); + if (match) { + const uri = (_d = (_c = e.dataTransfer.getData(match)) === null || _c === void 0 ? void 0 : _c.split("\n")) === null || _d === void 0 ? void 0 : _d[0]; + if (uri) { + return tryToGetWorkflowDataFromFile(await (await fetch(uri)).blob()); + } + } + return { workflow: null, prompt: null }; +} +export async function tryToGetWorkflowDataFromFile(file) { + var _a; + if (file.type === "image/png") { + const pngInfo = await getPngMetadata(file); + return { + workflow: parseWorkflowJson(pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.workflow), + prompt: parseWorkflowJson(pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.prompt), + }; + } + if (file.type === "image/webp") { + const pngInfo = await getWebpMetadata(file); + const workflow = parseWorkflowJson((pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.workflow) || (pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.Workflow) || "null"); + const prompt = parseWorkflowJson((pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.prompt) || (pngInfo === null || pngInfo === void 0 ? void 0 : pngInfo.Prompt) || "null"); + return { workflow, prompt }; + } + if (file.type === "application/json" || ((_a = file.name) === null || _a === void 0 ? void 0 : _a.endsWith(".json"))) { + const resolver = getResolver(); + const reader = new FileReader(); + reader.onload = async () => { + const json = parseWorkflowJson(reader.result); + const isApiJson = Object.values(json).every((v) => v.class_type); + const prompt = isApiJson ? json : null; + const workflow = !isApiJson && !(json === null || json === void 0 ? void 0 : json.templates) ? 
json : null; + return { workflow, prompt }; + }; + return resolver.promise; + } + return { workflow: null, prompt: null }; +} diff --git a/rgthree-comfy/web/link_fixer/icon_file_json.png b/rgthree-comfy/web/link_fixer/icon_file_json.png new file mode 100644 index 0000000000000000000000000000000000000000..ad3a1cb2b89a2051010d53d69b12cca8735af353 Binary files /dev/null and b/rgthree-comfy/web/link_fixer/icon_file_json.png differ diff --git a/rgthree-comfy/web/link_fixer/index.html b/rgthree-comfy/web/link_fixer/index.html new file mode 100644 index 0000000000000000000000000000000000000000..e998f8e07a8e8787cd2eb8c54c1b935b4cd4a12b --- /dev/null +++ b/rgthree-comfy/web/link_fixer/index.html @@ -0,0 +1,126 @@ + + + + rgthree's comfy: Workflow Link Fixer + + + + +
+ rgthree's Workflow Link Fixer
+ Early versions of the reroute node would occasionally leave behind stale node-linking data in the graph, which could sometimes cause erratic workflow loading. This tool will look at the metadata and attempt to fix these errors.
+ Drag and drop a comfy-generated image or workflow JSON into this window to check its serialized links and fix them.
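For orientation, a minimal usage sketch of the link fixer added above: it mirrors the report-then-fix passes that LinkPage.onFixClick performs in the link_page.js diff that follows. The checkAndFixWorkflow wrapper and workflowJsonString argument are hypothetical names used only for illustration; fixBadLinks, its arguments, and its return fields (hasBadLinks, patched, deleted, graph) come from the link_fixer.js file added earlier in this diff.

import { fixBadLinks } from "../common/link_fixer.js";

// Hypothetical wrapper: parse a serialized workflow, report stale links,
// and, if any are found, apply the same repeated fixing passes the page uses.
function checkAndFixWorkflow(workflowJsonString) {
  const graph = JSON.parse(workflowJsonString);
  // Report-only pass: fix=false logs findings without mutating the graph.
  const report = fixBadLinks(graph, false);
  if (!report.hasBadLinks) {
    return graph;
  }
  console.log(`Found ${report.patched} links to fix and ${report.deleted} to remove.`);
  // Fixing passes: fix=true mutates the graph; the page runs this more than
  // once because one repair can expose another stale link.
  let result = fixBadLinks(graph, true);
  if (result.patched || result.deleted) {
    result = fixBadLinks(result.graph, true);
  }
  return result.graph;
}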
+ + + + \ No newline at end of file diff --git a/rgthree-comfy/web/link_fixer/link_page.js b/rgthree-comfy/web/link_fixer/link_page.js new file mode 100644 index 0000000000000000000000000000000000000000..6fb54947edd4d93a3254f50d95ef11d83c3b26a2 --- /dev/null +++ b/rgthree-comfy/web/link_fixer/link_page.js @@ -0,0 +1,195 @@ +import { fixBadLinks } from "../common/link_fixer.js"; +import { getPngMetadata } from "../../scripts/pnginfo.js"; +function wait(ms = 16, value) { + return new Promise((resolve) => { + setTimeout(() => { + resolve(value); + }, ms); + }); +} +const logger = { + logTo: console, + log: (...args) => { + logger.logTo === console + ? console.log(...args) + : (logger.logTo.innerText += args.join(",") + "\n"); + }, +}; +const findBadLinksLogger = { + log: async (...args) => { + logger.log(...args); + }, +}; +export class LinkPage { + constructor() { + this.containerEl = document.querySelector(".box"); + this.figcaptionEl = document.querySelector("figcaption"); + this.outputeMessageEl = document.querySelector(".output"); + this.outputImageEl = document.querySelector(".output-image"); + this.btnFix = document.querySelector(".btn-fix"); + document.addEventListener("dragover", (e) => { + e.preventDefault(); + }, false); + document.addEventListener("drop", (e) => { + this.onDrop(e); + }); + this.btnFix.addEventListener("click", (e) => { + this.onFixClick(e); + }); + } + async onFixClick(e) { + if (!this.graphResults || !this.graph) { + this.updateUi("⛔ Fix button click without results."); + return; + } + let graphFinalResults = fixBadLinks(this.graph, true); + graphFinalResults = fixBadLinks(graphFinalResults.graph, true); + if (graphFinalResults.patched || graphFinalResults.deleted) { + graphFinalResults = fixBadLinks(graphFinalResults.graph, true); + } + this.graphFinalResults = graphFinalResults; + await this.saveFixedWorkflow(); + if (graphFinalResults.hasBadLinks) { + this.updateUi("⛔ Hmm... Still detecting bad links. Can you file an issue at https://github.com/rgthree/rgthree-comfy/issues with your image/workflow."); + } + else { + this.updateUi("✅ Workflow fixed.
Please load new saved workflow json and double check linking and execution."); + } + } + async onDrop(event) { + var _a, _b, _c, _d; + if (!event.dataTransfer) { + return; + } + this.reset(); + event.preventDefault(); + event.stopPropagation(); + if (event.dataTransfer.files.length && ((_b = (_a = event.dataTransfer.files) === null || _a === void 0 ? void 0 : _a[0]) === null || _b === void 0 ? void 0 : _b.type) !== "image/bmp") { + await this.handleFile(event.dataTransfer.files[0]); + return; + } + const validTypes = ["text/uri-list", "text/x-moz-url"]; + const match = [...event.dataTransfer.types].find((t) => validTypes.find((v) => t === v)); + if (match) { + const uri = (_d = (_c = event.dataTransfer.getData(match)) === null || _c === void 0 ? void 0 : _c.split("\n")) === null || _d === void 0 ? void 0 : _d[0]; + if (uri) { + await this.handleFile(await (await fetch(uri)).blob()); + } + } + } + reset() { + this.file = undefined; + this.graph = undefined; + this.graphResults = undefined; + this.graphFinalResults = undefined; + this.updateUi(); + } + updateUi(msg) { + this.outputeMessageEl.innerHTML = ""; + if (this.file && !this.containerEl.classList.contains("-has-file")) { + this.containerEl.classList.add("-has-file"); + this.figcaptionEl.innerHTML = this.file.name || this.file.type; + if (this.file.type === "application/json") { + this.outputImageEl.src = "icon_file_json.png"; + } + else { + const reader = new FileReader(); + reader.onload = () => (this.outputImageEl.src = reader.result); + reader.readAsDataURL(this.file); + } + } + else if (!this.file && this.containerEl.classList.contains("-has-file")) { + this.containerEl.classList.remove("-has-file"); + this.outputImageEl.src = ""; + this.outputImageEl.removeAttribute("src"); + } + if (this.graphResults) { + this.containerEl.classList.add("-has-results"); + if (!this.graphResults.patched && !this.graphResults.deleted) { + this.outputeMessageEl.innerHTML = "✅ No bad links detected in the workflow."; + } + else { + this.containerEl.classList.add("-has-fixable-results"); + this.outputeMessageEl.innerHTML = `⚠️ Found ${this.graphResults.patched} links to fix, and ${this.graphResults.deleted} to be removed.`; + } + } + else { + this.containerEl.classList.remove("-has-results"); + this.containerEl.classList.remove("-has-fixable-results"); + } + if (msg) { + this.outputeMessageEl.innerHTML = msg; + } + } + async handleFile(file) { + this.file = file; + this.updateUi(); + let workflow = null; + if (file.type.startsWith("image/")) { + const pngInfo = await getPngMetadata(file); + workflow = pngInfo === null || pngInfo === void 0 ? 
void 0 : pngInfo.workflow; + } + else if (file.type === "application/json" || + (file instanceof File && file.name.endsWith(".json"))) { + workflow = await new Promise((resolve) => { + const reader = new FileReader(); + reader.onload = () => { + resolve(reader.result); + }; + reader.readAsText(file); + }); + } + if (!workflow) { + this.updateUi("⛔ No workflow found in dropped item."); + } + else { + try { + this.graph = JSON.parse(workflow); + } + catch (e) { + this.graph = undefined; + } + if (!this.graph) { + this.updateUi("⛔ Invalid workflow found in dropped item."); + } + else { + this.loadGraphData(this.graph); + } + } + } + async loadGraphData(graphData) { + this.graphResults = await fixBadLinks(graphData); + this.updateUi(); + } + async saveFixedWorkflow() { + if (!this.graphFinalResults) { + this.updateUi("⛔ Save w/o final graph patched."); + return false; + } + let filename = this.file.name || "workflow.json"; + let filenames = filename.split("."); + filenames.pop(); + filename = filenames.join("."); + filename += "_fixed.json"; + filename = prompt("Save workflow as:", filename); + if (!filename) + return false; + if (!filename.toLowerCase().endsWith(".json")) { + filename += ".json"; + } + const json = JSON.stringify(this.graphFinalResults.graph, null, 2); + const blob = new Blob([json], { type: "application/json" }); + const url = URL.createObjectURL(blob); + const anchor = document.createElement("a"); + anchor.download = filename; + anchor.href = url; + anchor.style.display = "none"; + document.body.appendChild(anchor); + await wait(); + anchor.click(); + await wait(); + anchor.remove(); + window.URL.revokeObjectURL(url); + return true; + } +} diff --git a/simpleai-seamless-tiled/LICENSE b/simpleai-seamless-tiled/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f288702d2fa16d3cdf0035b15a9fcbc552cd88e7 --- /dev/null +++ b/simpleai-seamless-tiled/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. 
+ + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. 
+ + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. 
You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. 
+ + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. 
In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. 
+ + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. 
+ + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>. diff --git a/simpleai-seamless-tiled/README.md b/simpleai-seamless-tiled/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3cef892a4d73a95f75a08bd9d0af711f319f95f4 --- /dev/null +++ b/simpleai-seamless-tiled/README.md @@ -0,0 +1,20 @@ +# ComfyUI-seamless-tiling + +![tile](https://github.com/spacepxl/ComfyUI-seamless-tiling/assets/143970342/2cc548d6-b29e-4e7e-ac89-081498b47fe6) + +ComfyUI nodes for generating seamless textures. Replicates the "Tiling" option from A1111, including independent X/Y tiling. + +Use the "Seamless Tile" node between the loader and samplers to modify the model, and the "Make Circular VAE" or "Circular VAE Decode" node to decode the image. (Make Circular VAE is more efficient, since it only modifies the VAE model once instead of on each decode.) + +Use the "Offset Image" node to check for seams.
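For orientation, here is a minimal PyTorch sketch of the circular-padding idea these nodes rely on. It is an illustrative example, not the node code itself; the `make_seamless` helper name and the generic `model` argument are assumptions for the sketch:

```python
import torch

def make_seamless(model: torch.nn.Module, enable: bool = True) -> torch.nn.Module:
    """Toggle circular padding on every Conv2d so outputs wrap at the borders."""
    for layer in model.modules():
        if isinstance(layer, torch.nn.Conv2d):
            # Circular padding makes feature maps wrap around the image edges,
            # so the generated texture tiles without visible seams.
            # "zeros" restores the default behaviour.
            layer.padding_mode = "circular" if enable else "zeros"
    return model
```

The `SeamlessTile` node in `SeamlessTile.py` below refines this by patching the convolution forward pass so the X and Y axes can be made circular independently.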
+ +Circular VAE Decode code from https://github.com/FlyingFireCo/tiled_ksampler + +X/Y tiling implementation modified from https://github.com/tjm35/asymmetric-tiling-sd-webui + +``` +conditioning/Seamless Tile +latent/Circular VAE Decode (tile) +latent/Make Circular VAE +image/Offset Image +``` diff --git a/simpleai-seamless-tiled/SeamlessTile.py b/simpleai-seamless-tiled/SeamlessTile.py new file mode 100644 index 0000000000000000000000000000000000000000..bded4d80a7ce0273717222fddc14892c1facaf75 --- /dev/null +++ b/simpleai-seamless-tiled/SeamlessTile.py @@ -0,0 +1,259 @@ +import copy +from typing import Optional + +import PIL +import torch +from torch import Tensor +from torch.nn import Conv2d +from torch.nn import functional as F +from torch.nn.modules.utils import _pair +import comfy.samplers +import nodes +from typing import Optional + +class SeamlessTile: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "tiling": (["enable", "x_only", "y_only", "disable"],), + "copy_model": (["Make a copy", "Modify in place"],), + }, + } + + CATEGORY = "SeamlessTile" + + RETURN_TYPES = ("MODEL",) + FUNCTION = "run" + + def run(self, model, copy_model, tiling): + if copy_model == "Modify in place": + model_copy = model + else: + model_copy = copy.deepcopy(model) + + if tiling == "enable": + make_circular_asymm(model_copy.model, True, True) + elif tiling == "x_only": + make_circular_asymm(model_copy.model, True, False) + elif tiling == "y_only": + make_circular_asymm(model_copy.model, False, True) + else: + make_circular_asymm(model_copy.model, False, False) + return (model_copy,) + + +# asymmetric tiling from https://github.com/tjm35/asymmetric-tiling-sd-webui/blob/main/scripts/asymmetric_tiling.py +def make_circular_asymm(model, tileX: bool, tileY: bool): + for layer in [ + layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d) + ]: + layer.padding_modeX = 'circular' if tileX else 'constant' + layer.padding_modeY = 'circular' if tileY else 'constant' + layer.paddingX = (layer._reversed_padding_repeated_twice[0], layer._reversed_padding_repeated_twice[1], 0, 0) + layer.paddingY = (0, 0, layer._reversed_padding_repeated_twice[2], layer._reversed_padding_repeated_twice[3]) + layer._conv_forward = __replacementConv2DConvForward.__get__(layer, Conv2d) + return model + + +def __replacementConv2DConvForward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + working = F.pad(input, self.paddingX, mode=self.padding_modeX) + working = F.pad(working, self.paddingY, mode=self.padding_modeY) + return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups) + + +class CircularVAEDecode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "samples": ("LATENT",), + "vae": ("VAE",), + "tiling": (["enable", "x_only", "y_only", "disable"],) + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "decode" + + CATEGORY = "SeamlessTile" + + def decode(self, samples, vae, tiling): + vae_copy = copy.deepcopy(vae) + + if tiling == "enable": + make_circular_asymm(vae_copy.first_stage_model, True, True) + elif tiling == "x_only": + make_circular_asymm(vae_copy.first_stage_model, True, False) + elif tiling == "y_only": + make_circular_asymm(vae_copy.first_stage_model, False, True) + else: + make_circular_asymm(vae_copy.first_stage_model, False, False) + + result = (vae_copy.decode(samples["samples"]),) + return result + + +class MakeCircularVAE: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "vae": 
("VAE",), + "tiling": (["enable", "x_only", "y_only", "disable"],), + "copy_vae": (["Make a copy", "Modify in place"],), + } + } + + RETURN_TYPES = ("VAE",) + FUNCTION = "run" + CATEGORY = "SeamlessTile" + + def run(self, vae, tiling, copy_vae): + if copy_vae == "Modify in place": + vae_copy = vae + else: + vae_copy = copy.deepcopy(vae) + + if tiling == "enable": + make_circular_asymm(vae_copy.first_stage_model, True, True) + elif tiling == "x_only": + make_circular_asymm(vae_copy.first_stage_model, True, False) + elif tiling == "y_only": + make_circular_asymm(vae_copy.first_stage_model, False, True) + else: + make_circular_asymm(vae_copy.first_stage_model, False, False) + + return (vae_copy,) + + +class OffsetImage: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "pixels": ("IMAGE",), + "x_percent": ( + "FLOAT", + {"default": 50.0, "min": 0.0, "max": 100.0, "step": 1}, + ), + "y_percent": ( + "FLOAT", + {"default": 50.0, "min": 0.0, "max": 100.0, "step": 1}, + ), + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "run" + CATEGORY = "SeamlessTile" + + def run(self, pixels, x_percent, y_percent): + n, y, x, c = pixels.size() + y = round(y * y_percent / 100) + x = round(x * x_percent / 100) + return (pixels.roll((y, x), (1, 2)),) + +class TiledKSampler: + @classmethod + def INPUT_TYPES(cls): + return {"required": + {"model": ("MODEL", ), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "tiling": (["enable", "x_only", "y_only", "disable"],), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), + "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent_image": ("LATENT", ), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "sample" + + CATEGORY = "SeamlessTile" + def apply_circular(self, model, enable): + for layer in [layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d)]: + layer.padding_mode = 'circular' if enable else 'zeros' + + def sample(self, model, seed, tiling, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): + self.apply_circular(model.model, tiling in ["enable", "x_only", "y_only"]) + return nodes.common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) + + + +class Asymmetric_Tiled_KSampler: + @classmethod + def INPUT_TYPES(cls): + return {"required": + {"model": ("MODEL", ), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "tileX": ("INT", {"default": 1, "min": 0, "max": 1}), + "tileY": ("INT", {"default": 1, "min": 0, "max": 1}), + "startStep": ("INT", {"default": 0, "min": 0, "max": 10000}), + "stopStep": ("INT", {"default": -1, "min": -1, "max": 10000}), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), + "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent_image": ("LATENT", ), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "sample" + CATEGORY = "SeamlessTile" + + def apply_asymmetric_tiling(self, model, tileX, tileY): + for layer in 
[layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d)]: + layer.padding_modeX = 'circular' if tileX else 'constant' + layer.padding_modeY = 'circular' if tileY else 'constant' + layer.paddingX = (layer._reversed_padding_repeated_twice[0], layer._reversed_padding_repeated_twice[1], 0, 0) + layer.paddingY = (0, 0, layer._reversed_padding_repeated_twice[2], layer._reversed_padding_repeated_twice[3]) + print(layer.paddingX, layer.paddingY) + + def __hijackConv2DMethods(self, model, tileX: bool, tileY: bool, startStep: int, stopStep: int): + for layer in [l for l in model.modules() if isinstance(l, torch.nn.Conv2d)]: + layer.padding_modeX = 'circular' if tileX else 'constant' + layer.padding_modeY = 'circular' if tileY else 'constant' + layer.paddingX = (layer._reversed_padding_repeated_twice[0], layer._reversed_padding_repeated_twice[1], 0, 0) + layer.paddingY = (0, 0, layer._reversed_padding_repeated_twice[2], layer._reversed_padding_repeated_twice[3]) + layer.paddingStartStep = startStep + layer.paddingStopStep = stopStep + + def make_bound_method(method, current_layer): + def bound_method(self, *args, **kwargs): # Add 'self' here + return method(current_layer, *args, **kwargs) + return bound_method + + bound_method = make_bound_method(self.__replacementConv2DConvForward, layer) + layer._conv_forward = bound_method.__get__(layer, type(layer)) + + def __replacementConv2DConvForward(self, layer, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]): + step = nodes.common_ksampler.current_step # Assuming there's a way to get the current step + if ((layer.paddingStartStep < 0 or step >= layer.paddingStartStep) and (layer.paddingStopStep < 0 or step <= layer.paddingStopStep)): + working = torch.nn.functional.pad(input, layer.paddingX, mode=layer.padding_modeX) + working = torch.nn.functional.pad(working, layer.paddingY, mode=layer.padding_modeY) + else: + working = torch.nn.functional.pad(input, layer.paddingX, mode='constant') + working = torch.nn.functional.pad(working, layer.paddingY, mode='constant') + return torch.nn.functional.conv2d(working, weight, bias, layer.stride, (0, 0), layer.dilation, layer.groups) + + + def __restoreConv2DMethods(self, model): + for layer in [l for l in model.modules() if isinstance(l, torch.nn.Conv2d)]: + layer._conv_forward = torch.nn.Conv2d._conv_forward.__get__(layer, torch.nn.Conv2d) + + + def sample(self, model, seed, tileX, tileY, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, startStep=0, stopStep=-1): + self.__hijackConv2DMethods(model.model, tileX == 1, tileY == 1, startStep, stopStep) + result = nodes.common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) + self.__restoreConv2DMethods(model.model) + return result + diff --git a/simpleai-seamless-tiled/__init__.py b/simpleai-seamless-tiled/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb9cffeb4dbe69f9176554c582e25923f6b4d692 --- /dev/null +++ b/simpleai-seamless-tiled/__init__.py @@ -0,0 +1,22 @@ +from .SeamlessTile import (CircularVAEDecode, MakeCircularVAE, OffsetImage, + TiledKSampler, Asymmetric_Tiled_KSampler, SeamlessTile) + +NODE_CLASS_MAPPINGS = { + "TiledKSampler": TiledKSampler, + "AsymmetricTiledKSampler": Asymmetric_Tiled_KSampler, + "SeamlessTile": SeamlessTile, + "CircularVAEDecode": CircularVAEDecode, + "MakeCircularVAE": MakeCircularVAE, + "OffsetImage": OffsetImage, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + 
"TiledKSampler": "KSampler(Tiled)", + "AsymmetricTiledKSampler": "KSampler(AsymmetricTiled)", + "SeamlessTile": "Seamless Tile", + "CircularVAEDecode": "Circular VAE Decode (Tiled)", + "MakeCircularVAE": "Make Circular VAE", + "OffsetImage": "Offset Image", +} + +__all__ = [NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS] diff --git a/simpleai-seamless-tiled/pyproject.toml b/simpleai-seamless-tiled/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..9a964479c09ad3ffe5f6a334616c4fc8e2631a72 --- /dev/null +++ b/simpleai-seamless-tiled/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "comfyui-seamless-tiling" +description = "Node for generating seamless textures, based on similar setting from A1111." +version = "1.0.0" +license = { file = "LICENSE" } + +[project.urls] +Repository = "https://github.com/spinagon/ComfyUI-seamless-tiling" +# Used by Comfy Registry https://comfyregistry.org + +[tool.comfy] +PublisherId = "spinagon" +DisplayName = "ComfyUI-seamless-tiling" +Icon = "" diff --git a/websocket_image_save.py b/websocket_image_save.py new file mode 100644 index 0000000000000000000000000000000000000000..dd8d188843fabb10580443b7df534add0e55e962 --- /dev/null +++ b/websocket_image_save.py @@ -0,0 +1,45 @@ +from PIL import Image, ImageOps +from io import BytesIO +import numpy as np +import struct +import comfy.utils +import time + +#You can use this node to save full size images through the websocket, the +#images will be sent in exactly the same format as the image previews: as +#binary images on the websocket with a 8 byte header indicating the type +#of binary message (first 4 bytes) and the image format (next 4 bytes). + +#Note that no metadata will be put in the images saved with this node. + +class SaveImageWebsocket: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ),} + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "api/image" + + def save_images(self, images): + pbar = comfy.utils.ProgressBar(images.shape[0]) + step = 0 + for image in images: + i = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pbar.update_absolute(step, images.shape[0], ("PNG", img, None)) + step += 1 + + return {} + + def IS_CHANGED(images): + return time.time() + +NODE_CLASS_MAPPINGS = { + "SaveImageWebsocket": SaveImageWebsocket, +} diff --git a/x-flux-comfyui/CHANGELOG.md b/x-flux-comfyui/CHANGELOG.md new file mode 100644 index 0000000000000000000000000000000000000000..763ea175356435fb3e351c087bf1c47f5437f7a6 --- /dev/null +++ b/x-flux-comfyui/CHANGELOG.md @@ -0,0 +1,45 @@ +# Changelog + +## [TODO] + +### Add + +- IPAdapter controlling features +- IPAdapter compatablity with VIT +- Offloading support + +## [Unreleased] + +### Added + + +## [0.4.2] + +### Added +- New clip-vit +- Latent preview +- GGUF Support + +## [0.3.45] + +### Added + +- IP adapter support + + +## [0.2.38] + +### Added + +- Image-to-image support + +### Changed + +- Updated readme.md, added + +## [0.1.0] + +### Added + +- Lora and controlnets support +- xlabs Sampler diff --git a/x-flux-comfyui/Guide.md b/x-flux-comfyui/Guide.md new file mode 100644 index 0000000000000000000000000000000000000000..f0301d069c4d32ec88c77309d8d2dff31a28f7ec --- /dev/null +++ b/x-flux-comfyui/Guide.md @@ -0,0 +1,121 @@ +# Guide + +# Installing + +First of all, you should install ComfyUI and [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager). + +After that, launch ComfUI. 
+ + + +In the right panel you will find the “Manager” button; click it. + +![manager.png](guide/manager.png) + +A big window will open; click “Custom Nodes Manager”. + +![manager_menu.png](guide/manager_menu.png) + +Go to the search field and start typing “x-flux-comfyui”. + +![search.png](guide/search.png) + +Click the “Install” button. + +![download.png](guide/download.png) + +Now you have access to the X-Labs nodes; you can find them in the “XLabsNodes” category. + +![nodes.png](guide/manager.png) + +# Installing Flux + +There is a complete guide by comfyanonymous: [Guide](https://comfyanonymous.github.io/ComfyUI_examples/flux/) + +Important! Use FLUX dev, not schnell. If your computer is capable of it, it is better to use fp8 or bf16 (the default). + +# Nodes + +## XLabs Sampler + +![sampler.png](guide/sampler.png) + +Node inputs: + +model: the FLUX diffusion model (from a UNet loader) + +conditioning & neg_conditioning: the prompts encoded by the T5 and CLIP text models (CLIP alone is allowed, but be aware that you will leave about 40% of FLUX's power unused, so use the dual text node) + +latent_image: the latent input for FLUX; it may be an empty latent or an image encoded with the FLUX AE (VAE Encode), for image-to-image use + +controlnet_condition: input for XLabs-AI ControlNet conditioning + +Output: + +latent: the FLUX latent image; decode it with a VAE Decode node to get the final image + +Parameters: + +noise_seed, control_after_generate: control the random generator + +steps: how many denoising steps the diffusion will run + +timestep_to_start_cfg: how many steps the diffusion runs before negative conditioning and CFG are applied + +true_gs: the true CFG scale, used after the first **timestep_to_start_cfg** steps + +image_to_image_strength: how strongly the original image affects the output + +denoise_strength: how much noise remains + +## Load Flux LoRA + +![lora.png](guide/lora.png) + +Takes a FLUX model as input and outputs the model modified with an XLabs LoRA (the LoRA name and strength are selected by the parameters). + +## Load Flux ControlNet + +![loadcnet.png](guide/loadcnet.png) + +Takes a ControlNet name and returns a FLUX ControlNet. + +## Apply Flux ControlNet + +![manager.png](guide/manager.png) + +Takes a ControlNet, an image, and a strength parameter. Returns a ControlNet condition for the XLabs Sampler. + +## Load Flux IPAdapter + +![loadip.png](guide/loadip.png) + +Takes an IPAdapter name, a CLIP ViT model, and the device it will run on. Choose CUDA only if you have enough VRAM. Returns a Flux IPAdapter. + +## Apply Flux IPAdapter + +![applyip.png](guide/applyip.png) + +Takes a FLUX model, an IPAdapter, and an image. Returns the modified model. The IPAdapter strength is set by the parameter. + +## Apply Advanced IPAdapter + +![advip.png](guide/advip.png) + +Like the common one, but with more strength parameters. + +# Models + +We use custom folders for LoRAs, ControlNets, and IPAdapters; they live in `models\xlabs`. + +LoRAs go to `ComfyUI\models\xlabs\loras`. + +ControlNets go to `ComfyUI\models\xlabs\controlnets`. + +IPAdapters go to `ComfyUI\models\xlabs\ipadapters`. + +An IPAdapter requires a CLIP ViT. We currently use OpenAI CLIP ViT-Large. You can find it here: + +[CLIP ViT model](https://huggingface.co/openai/clip-vit-large-patch14). + +Download model.safetensors, rename it as you like (the .safetensors extension is required), and put it in `ComfyUI\models\clip-vision\`.
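As a quick way to sanity-check the folder layout described above, here is a small, hypothetical helper; the `COMFY_ROOT` value is an assumption, so point it at your own ComfyUI install:

```python
import os

# Assumed root of your ComfyUI installation; adjust as needed.
COMFY_ROOT = "ComfyUI"

for sub in ("loras", "controlnets", "ipadapters"):
    path = os.path.join(COMFY_ROOT, "models", "xlabs", sub)
    os.makedirs(path, exist_ok=True)  # the nodes also create these on first launch
    contents = os.listdir(path)
    print(f"{path}: {contents if contents else 'empty'}")
```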
diff --git a/x-flux-comfyui/LICENSE b/x-flux-comfyui/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/x-flux-comfyui/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/x-flux-comfyui/Readme.md b/x-flux-comfyui/Readme.md new file mode 100644 index 0000000000000000000000000000000000000000..c6f7586430bdfba7f4bb382aa733e71c694aacdc --- /dev/null +++ b/x-flux-comfyui/Readme.md @@ -0,0 +1,52 @@ +# *[Guide](/Guide.md)* + +# How to use +![FLUX Finetuning scripts](./assets/flux-comfy-ui-nodes-dark-rev1.png) +![FLUX Finetuning scripts](./assets/image1.png) + +## Installation: + +1. Go to `ComfyUI/custom_nodes` +2. Clone this repo, path should be `ComfyUI/custom_nodes/x-flux-comfyui/*`, where * is all the files in this repo +3. Go to `ComfyUI/custom_nodes/x-flux-comfyui/` and run `python setup.py` +4. Run ComfyUI after installing and enjoy! + +After the first launch, the `ComfyUI/models/xlabs/loras` and `ComfyUI/models/xlabs/controlnets` folders will be created automatically.
+So, to use a LoRA or a ControlNet, just put the model files in these folders.
+After that, you may need to click "Refresh" in the ComfyUI interface before the models become available.
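+For example, assuming you have already downloaded a LoRA and a ControlNet checkpoint (the file names below are placeholders, not files shipped with this repo), placing them looks like this:
+```bash
+# Copy downloaded checkpoints into the folders created on first launch,
+# then click "Refresh" in ComfyUI so the loader nodes can see them.
+cp ~/Downloads/flux_realism_lora.safetensors ComfyUI/models/xlabs/loras/
+cp ~/Downloads/flux_canny_controlnet.safetensors ComfyUI/models/xlabs/controlnets/
+```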
+To use ControlNet you also need to install https://github.com/Fannovel16/comfyui_controlnet_aux (a typical install is sketched below).
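+A minimal install sketch, assuming a standard ComfyUI checkout and that the preprocessor repo ships a `requirements.txt` (as most custom-node packs do):
+```bash
+cd ComfyUI/custom_nodes
+git clone https://github.com/Fannovel16/comfyui_controlnet_aux
+pip install -r comfyui_controlnet_aux/requirements.txt
+# Restart ComfyUI afterwards so the new preprocessor nodes are registered.
+```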
+## Low memory mode
+You can run Flux with about 12 GB of VRAM.
+1. Follow the installation steps described in https://github.com/city96/ComfyUI-GGUF
+2. Use `flux1-dev-Q4_0.gguf` from https://github.com/city96/ComfyUI-GGUF
+3. Launch ComfyUI with the following parameters:
+```bash
+python3 main.py --lowvram --preview-method auto --use-split-cross-attention
+```
+In our workflows, replace the "Load Diffusion Model" node with "Unet Loader (GGUF)".
+
+![FLUX Finetuning scripts](./assets/low_memory_mode.png)
+
+## Models
+
+We trained **Canny ControlNet**, **Depth ControlNet**, **HED ControlNet** and **LoRA** checkpoints for [`FLUX.1 [dev]`](https://github.com/black-forest-labs/flux).
+You can download them on HuggingFace: + +- [flux-controlnet-collections](https://huggingface.co/XLabs-AI/flux-controlnet-collections) +- [flux-controlnet-canny](https://huggingface.co/XLabs-AI/flux-controlnet-canny) +- [flux-RealismLora](https://huggingface.co/XLabs-AI/flux-RealismLora) +- [flux-lora-collections](https://huggingface.co/XLabs-AI/flux-lora-collection) +- [flux-furry-lora](https://huggingface.co/XLabs-AI/flux-furry-lora) +- [flux-ip-adapter](https://huggingface.co/XLabs-AI/flux-ip-adapter/) +## IP Adapter + +### Instruction +1. Update x-flux-comfy with `git pull` or reinstall it. +2. Download Clip-L `model.safetensors` from [OpenAI VIT CLIP large](https://huggingface.co/openai/clip-vit-large-patch14), and put it to `ComfyUI/models/clip_vision/*`. +3. Download our IPAdapter from [huggingface](https://huggingface.co/XLabs-AI/flux-ip-adapter/tree/main), and put it to `ComfyUI/models/xlabs/ipadapters/*`. +4. Use `Flux Load IPAdapter` and `Apply Flux IPAdapter` nodes, choose right CLIP model and enjoy your genereations. +5. You can find example workflow in folder workflows in this repo. + +### Limitations +The IP Adapter is currently in beta. +We do not guarantee that you will get a good result right away, it may take more attempts to get a result. But we will make efforts to make this process easier and more efficient over time. diff --git a/x-flux-comfyui/__init__.py b/x-flux-comfyui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e96bd6ab3db650f769ae7886e0c13515752bd16 --- /dev/null +++ b/x-flux-comfyui/__init__.py @@ -0,0 +1,3 @@ +from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] \ No newline at end of file diff --git a/x-flux-comfyui/assets/flux-comfy-ui-nodes-dark-rev1.png b/x-flux-comfyui/assets/flux-comfy-ui-nodes-dark-rev1.png new file mode 100644 index 0000000000000000000000000000000000000000..f8af26baf009775ac15b2cb6e3c23cc3129bc47d Binary files /dev/null and b/x-flux-comfyui/assets/flux-comfy-ui-nodes-dark-rev1.png differ diff --git a/x-flux-comfyui/assets/image1.png b/x-flux-comfyui/assets/image1.png new file mode 100644 index 0000000000000000000000000000000000000000..28894009fa312987de0b418d4aec392e7f9deae9 --- /dev/null +++ b/x-flux-comfyui/assets/image1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ef4d278b8464e7a891e11911c5562152a3f1b38917df710b6318c2bdc042fcb +size 1141082 diff --git a/x-flux-comfyui/assets/low_memory_mode.png b/x-flux-comfyui/assets/low_memory_mode.png new file mode 100644 index 0000000000000000000000000000000000000000..470e25bd94678f54bc22e597152df8a7523cb12a Binary files /dev/null and b/x-flux-comfyui/assets/low_memory_mode.png differ diff --git a/x-flux-comfyui/clip.py b/x-flux-comfyui/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0f5749c9c1a6412d0defad547562ac4a5cea45ea --- /dev/null +++ b/x-flux-comfyui/clip.py @@ -0,0 +1,192 @@ +import json +import os +from transformers import (CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPVisionConfig, + AutoConfig) + + + +class FluxClipViT: + def __init__(self, path_model = None): + if path_model is None: + self.model = CLIPVisionModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14" + ) + + else: + _dir = os.path.dirname(path_model) + write_config(_dir) + config = CLIPVisionConfig.from_pretrained( + os.path.join(_dir, "flux_clip_config.json") + ) + self.model = 
CLIPVisionModelWithProjection.from_pretrained( + path_model, + config=config, + use_safetensors = True, + ) + self.image_processor = CLIPImageProcessor() + self.load_device = next(self.model.parameters()).device + + def __call__(self, image): + img = self.image_processor( + images=image, return_tensors="pt" + ) + img = img.pixel_values + return self.model(img).image_embeds + + +def write_config(path): + #check if exists + if os.path.exists(os.path.join(path, "flux_clip_config.json")): + return + with open(os.path.join(path, "flux_clip_config.json"), "w") as f: + json.dump(json_config, f, indent=4) + +json_config = {'_name_or_path': 'clip-vit-large-patch14/', + 'architectures': ['CLIPModel'], + 'initializer_factor': 1.0, + 'logit_scale_init_value': 2.6592, + 'model_type': 'clip', + 'projection_dim': 768, + 'text_config': {'_name_or_path': '', + 'add_cross_attention': False, + 'architectures': None, + 'attention_dropout': 0.0, + 'bad_words_ids': None, + 'bos_token_id': 0, + 'chunk_size_feed_forward': 0, + 'cross_attention_hidden_size': None, + 'decoder_start_token_id': None, + 'diversity_penalty': 0.0, + 'do_sample': False, + 'dropout': 0.0, + 'early_stopping': False, + 'encoder_no_repeat_ngram_size': 0, + 'eos_token_id': 2, + 'finetuning_task': None, + 'forced_bos_token_id': None, + 'forced_eos_token_id': None, + 'hidden_act': 'quick_gelu', + 'hidden_size': 768, + 'id2label': {'0': 'LABEL_0', '1': 'LABEL_1'}, + 'initializer_factor': 1.0, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'is_decoder': False, + 'is_encoder_decoder': False, + 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, + 'layer_norm_eps': 1e-05, + 'length_penalty': 1.0, + 'max_length': 20, + 'max_position_embeddings': 77, + 'min_length': 0, + 'model_type': 'clip_text_model', + 'no_repeat_ngram_size': 0, + 'num_attention_heads': 12, + 'num_beam_groups': 1, + 'num_beams': 1, + 'num_hidden_layers': 12, + 'num_return_sequences': 1, + 'output_attentions': False, + 'output_hidden_states': False, + 'output_scores': False, + 'pad_token_id': 1, + 'prefix': None, + 'problem_type': None, + 'projection_dim': 768, + 'pruned_heads': {}, + 'remove_invalid_values': False, + 'repetition_penalty': 1.0, + 'return_dict': True, + 'return_dict_in_generate': False, + 'sep_token_id': None, + 'task_specific_params': None, + 'temperature': 1.0, + 'tie_encoder_decoder': False, + 'tie_word_embeddings': True, + 'tokenizer_class': None, + 'top_k': 50, + 'top_p': 1.0, + 'torch_dtype': None, + 'torchscript': False, + 'transformers_version': '4.16.0.dev0', + 'use_bfloat16': False, + 'vocab_size': 49408}, + 'text_config_dict': {'hidden_size': 768, + 'intermediate_size': 3072, + 'num_attention_heads': 12, + 'num_hidden_layers': 12, + 'projection_dim': 768}, + 'torch_dtype': 'float32', + 'transformers_version': None, + 'vision_config': {'_name_or_path': '', + 'add_cross_attention': False, + 'architectures': None, + 'attention_dropout': 0.0, + 'bad_words_ids': None, + 'bos_token_id': None, + 'chunk_size_feed_forward': 0, + 'cross_attention_hidden_size': None, + 'decoder_start_token_id': None, + 'diversity_penalty': 0.0, + 'do_sample': False, + 'dropout': 0.0, + 'early_stopping': False, + 'encoder_no_repeat_ngram_size': 0, + 'eos_token_id': None, + 'finetuning_task': None, + 'forced_bos_token_id': None, + 'forced_eos_token_id': None, + 'hidden_act': 'quick_gelu', + 'hidden_size': 1024, + 'id2label': {'0': 'LABEL_0', '1': 'LABEL_1'}, + 'image_size': 224, + 'initializer_factor': 1.0, + 'initializer_range': 0.02, + 'intermediate_size': 4096, + 'is_decoder': 
False, + 'is_encoder_decoder': False, + 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, + 'layer_norm_eps': 1e-05, + 'length_penalty': 1.0, + 'max_length': 20, + 'min_length': 0, + 'model_type': 'clip_vision_model', + 'no_repeat_ngram_size': 0, + 'num_attention_heads': 16, + 'num_beam_groups': 1, + 'num_beams': 1, + 'num_hidden_layers': 24, + 'num_return_sequences': 1, + 'output_attentions': False, + 'output_hidden_states': False, + 'output_scores': False, + 'pad_token_id': None, + 'patch_size': 14, + 'prefix': None, + 'problem_type': None, + 'projection_dim': 768, + 'pruned_heads': {}, + 'remove_invalid_values': False, + 'repetition_penalty': 1.0, + 'return_dict': True, + 'return_dict_in_generate': False, + 'sep_token_id': None, + 'task_specific_params': None, + 'temperature': 1.0, + 'tie_encoder_decoder': False, + 'tie_word_embeddings': True, + 'tokenizer_class': None, + 'top_k': 50, + 'top_p': 1.0, + 'torch_dtype': None, + 'torchscript': False, + 'transformers_version': '4.16.0.dev0', + 'use_bfloat16': False}, + 'vision_config_dict': {'hidden_size': 1024, + 'intermediate_size': 4096, + 'num_attention_heads': 16, + 'num_hidden_layers': 24, + 'patch_size': 14, + 'projection_dim': 768}} diff --git a/x-flux-comfyui/guide/advip.png b/x-flux-comfyui/guide/advip.png new file mode 100644 index 0000000000000000000000000000000000000000..c484cc30ba48fca53ca4f07edd03960291cd8e24 Binary files /dev/null and b/x-flux-comfyui/guide/advip.png differ diff --git a/x-flux-comfyui/guide/applycnet.png b/x-flux-comfyui/guide/applycnet.png new file mode 100644 index 0000000000000000000000000000000000000000..d8eef47c00ac1004d75901317e62df5c8abbe622 Binary files /dev/null and b/x-flux-comfyui/guide/applycnet.png differ diff --git a/x-flux-comfyui/guide/applyip.png b/x-flux-comfyui/guide/applyip.png new file mode 100644 index 0000000000000000000000000000000000000000..debda750a20c78fa211a29674f176d12c890901e Binary files /dev/null and b/x-flux-comfyui/guide/applyip.png differ diff --git a/x-flux-comfyui/guide/download.png b/x-flux-comfyui/guide/download.png new file mode 100644 index 0000000000000000000000000000000000000000..d18b50cb3fec0f90960f4cb17fac882239ae8813 Binary files /dev/null and b/x-flux-comfyui/guide/download.png differ diff --git a/x-flux-comfyui/guide/loadcnet.png b/x-flux-comfyui/guide/loadcnet.png new file mode 100644 index 0000000000000000000000000000000000000000..152c4895c60b411a317b0668a051696bbcc14a69 Binary files /dev/null and b/x-flux-comfyui/guide/loadcnet.png differ diff --git a/x-flux-comfyui/guide/loadip.png b/x-flux-comfyui/guide/loadip.png new file mode 100644 index 0000000000000000000000000000000000000000..d57d5f49105d0edea14bf4b2c321e95851d9fc49 Binary files /dev/null and b/x-flux-comfyui/guide/loadip.png differ diff --git a/x-flux-comfyui/guide/lora.png b/x-flux-comfyui/guide/lora.png new file mode 100644 index 0000000000000000000000000000000000000000..f39215ea60ce045f635d4aa0dbbb16495d128104 Binary files /dev/null and b/x-flux-comfyui/guide/lora.png differ diff --git a/x-flux-comfyui/guide/manager.png b/x-flux-comfyui/guide/manager.png new file mode 100644 index 0000000000000000000000000000000000000000..1adc9055c5c47dcc96971fc7d65360b9ec345afe Binary files /dev/null and b/x-flux-comfyui/guide/manager.png differ diff --git a/x-flux-comfyui/guide/manager_menu.png b/x-flux-comfyui/guide/manager_menu.png new file mode 100644 index 0000000000000000000000000000000000000000..f06b5ff75207aa0421ef61630f56d3d06a79f2a9 --- /dev/null +++ b/x-flux-comfyui/guide/manager_menu.png @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:3b41216685d605cf75cae0a1a3f61d0464f33ea54136b2e9c6da0713cc94e5f0 +size 109453 diff --git a/x-flux-comfyui/guide/nodes.png b/x-flux-comfyui/guide/nodes.png new file mode 100644 index 0000000000000000000000000000000000000000..aba17dfa10137683c0a414b7c7ed47d05c926046 Binary files /dev/null and b/x-flux-comfyui/guide/nodes.png differ diff --git a/x-flux-comfyui/guide/sampler.png b/x-flux-comfyui/guide/sampler.png new file mode 100644 index 0000000000000000000000000000000000000000..cca7cf2951120c083eef4e50f19f7d8c79f69d51 Binary files /dev/null and b/x-flux-comfyui/guide/sampler.png differ diff --git a/x-flux-comfyui/guide/search.png b/x-flux-comfyui/guide/search.png new file mode 100644 index 0000000000000000000000000000000000000000..54e2d155fe7823e277811ea67606ff289717eb4b Binary files /dev/null and b/x-flux-comfyui/guide/search.png differ diff --git a/x-flux-comfyui/layers.py b/x-flux-comfyui/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..366ca04aba97acb28b2990e6f5a092897201de21 --- /dev/null +++ b/x-flux-comfyui/layers.py @@ -0,0 +1,384 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from .xflux.src.flux.math import attention, rope +from .xflux.src.flux.modules.layers import LoRALinearLayer + +from torch.nn import functional as F +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + t.device + ) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + +class DoubleStreamBlockLorasMixerProcessor(nn.Module): + def __init__(self,): + super().__init__() + self.qkv_lora1 = [] + self.proj_lora1 = [] + self.qkv_lora2 = [] + self.proj_lora2 = [] + self.lora_weight = [] + self.names = [] + def add_lora(self, processor): + if isinstance(processor, DoubleStreamBlockLorasMixerProcessor): + self.qkv_lora1+=processor.qkv_lora1 + self.qkv_lora2+=processor.qkv_lora2 + self.proj_lora1+=processor.proj_lora1 + self.proj_lora2+=processor.proj_lora2 + self.lora_weight+=processor.lora_weight + else: + if hasattr(processor, "qkv_lora1"): + self.qkv_lora1.append(processor.qkv_lora1) + if hasattr(processor, "proj_lora1"): + self.proj_lora1.append(processor.proj_lora1) + if hasattr(processor, "qkv_lora2"): + self.qkv_lora2.append(processor.qkv_lora2) + if hasattr(processor, "proj_lora2"): + self.proj_lora2.append(processor.proj_lora2) + if hasattr(processor, "lora_weight"): + self.lora_weight.append(processor.lora_weight) + def get_loras(self): + return ( + self.qkv_lora1, self.qkv_lora2, + self.proj_lora1, self.proj_lora2, + self.lora_weight + ) + def set_loras(self, qkv1s, qkv2s, proj1s, proj2s, w8s): + for el in qkv1s: + self.qkv_lora1.append(el) + for el in qkv2s: + self.qkv_lora2.append(el) + for el in proj1s: + self.proj_lora1.append(el) + for el in proj2s: + self.proj_lora2.append(el) + for el in w8s: + self.lora_weight.append(el) + + def add_shift(self, layer, origin, inputs, gating = 1.0): + #shift = torch.zeros_like(origin) + count = len(layer) + for i in range(count): + origin += layer[i](inputs)*self.lora_weight[i]*gating + + def forward(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + + #img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight + img_qkv = attn.img_attn.qkv(img_modulated) + #print(self.qkv_lora1) + self.add_shift(self.qkv_lora1, img_qkv, img_modulated) + + + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + + + #txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight + txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.add_shift(self.qkv_lora2, txt_qkv, txt_modulated) + + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + #img = img + img_mod1.gate * 
attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + self.add_shift(self.proj_lora1, img, img_attn, img_mod1.gate) + + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + #txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + self.add_shift(self.proj_lora2, txt, txt_attn, txt_mod1.gate) + + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + +class DoubleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + +class DoubleStreamBlockProcessor(nn.Module): + def __init__(self): + super().__init__() + def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + 
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + def forward(self, attn, img, txt, vec, pe, **attention_kwargs): + self.__call__(attn, img, txt, vec, pe, **attention_kwargs) + + +class IPProcessor(nn.Module): + def __init__(self, context_dim, hidden_dim, ip_hidden_states=None, ip_scale=None): + super().__init__() + self.ip_hidden_states = ip_hidden_states + self.ip_scale = ip_scale + self.in_hidden_states_neg = None + self.in_hidden_states_pos = ip_hidden_states + # Ensure context_dim matches the dimension of ip_hidden_states + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + # Initialize projections for IP-adapter + self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True) + self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True) + + nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight) + nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias) + + nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight) + nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias) + + def forward(self, img_q, attn): + #img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + # IP-adapter processing + ip_query = img_q # latent sample query + ip_key = self.ip_adapter_double_stream_k_proj(self.ip_hidden_states) + ip_value = self.ip_adapter_double_stream_v_proj(self.ip_hidden_states) + + # Reshape projections for multi-head attention + ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads) + ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads) + #img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + # Compute attention between IP projections and the latent query + ip_attention = F.scaled_dot_product_attention( + ip_query, + ip_key, + ip_value, + dropout_p=0.0, + is_causal=False + ) + ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads) + return ip_attention*self.ip_scale + +class ImageProjModel(torch.nn.Module): + """Projection Model + https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28 + """ + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + 
clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class DoubleStreamMixerProcessor(DoubleStreamBlockLorasMixerProcessor): + def __init__(self,): + super().__init__() + self.ip_adapters = nn.ModuleList() + + def add_ipadapter(self, ip_adapter): + self.ip_adapters.append(ip_adapter) + + def get_ip_adapters(self): + return self.ip_adapters + def set_ip_adapters(self, ip_adapters): + self.ip_adapters = ip_adapters + def shift_ip(self, img_qkv, attn, x): + for block in self.ip_adapters: + x += block(img_qkv, attn) + return x + def add_lora(self, processor): + if isinstance(processor, DoubleStreamBlockLorasMixerProcessor): + self.qkv_lora1+=processor.qkv_lora1 + self.qkv_lora2+=processor.qkv_lora2 + self.proj_lora1+=processor.proj_lora1 + self.proj_lora2+=processor.proj_lora2 + self.lora_weight+=processor.lora_weight + elif isinstance(processor, DoubleStreamMixerProcessor): + self.qkv_lora1+=processor.qkv_lora1 + self.qkv_lora2+=processor.qkv_lora2 + self.proj_lora1+=processor.proj_lora1 + self.proj_lora2+=processor.proj_lora2 + self.lora_weight+=processor.lora_weight + else: + if hasattr(processor, "qkv_lora1"): + self.qkv_lora1.append(processor.qkv_lora1) + if hasattr(processor, "proj_lora1"): + self.proj_lora1.append(processor.proj_lora1) + if hasattr(processor, "qkv_lora2"): + self.qkv_lora2.append(processor.qkv_lora2) + if hasattr(processor, "proj_lora2"): + self.proj_lora2.append(processor.proj_lora2) + if hasattr(processor, "lora_weight"): + self.lora_weight.append(processor.lora_weight) + + def forward(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + + #img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight + img_qkv = attn.img_attn.qkv(img_modulated) + #print(self.qkv_lora1) + self.add_shift(self.qkv_lora1, img_qkv, img_modulated) + + + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + + + #txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight + txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.add_shift(self.qkv_lora2, txt_qkv, txt_modulated) + + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + #img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + self.add_shift(self.proj_lora1, img, img_attn, img_mod1.gate) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + + self.shift_ip(img_q, attn, img) + # calculate 
the txt bloks + #txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + + + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + self.add_shift(self.proj_lora2, txt, txt_attn, txt_mod1.gate) + + return img, txt diff --git a/x-flux-comfyui/model_init.py b/x-flux-comfyui/model_init.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca0d95ee6db6a12ab9ec63c5db6dbd3e341fc2f --- /dev/null +++ b/x-flux-comfyui/model_init.py @@ -0,0 +1,218 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + + +from .xflux.src.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) +from .xflux.src.flux.model import FluxParams + + +def convert_to_dtype(block, dtype): + block.to(dtype) + return block +def double_blocks_init(model, params, dtype): + model.double_blocks = nn.ModuleList( + [ + convert_to_dtype( + DoubleStreamBlock( + model.hidden_size, + model.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ), + dtype + ) + for _ in range(params.depth) + ] + ) +def single_blocks_init(model, params, dtype): + model.single_blocks = nn.ModuleList( + [ + convert_to_dtype( + SingleStreamBlock(model.hidden_size, model.num_heads, mlp_ratio=params.mlp_ratio), + dtype + ) + for _ in range(params.depth_single_blocks) + + ] + ) + + model.final_layer = LastLayer(model.hidden_size, 1, model.out_channels) + model.final_layer.to(dtype) + + + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, 
processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + block_controlnet_hidden_states=None, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + for index_block, block in enumerate(self.double_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + ) + else: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + # controlnet residual + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[index_block % 2] + + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = 
torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + vec, + pe, + ) + else: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/x-flux-comfyui/nodes.py b/x-flux-comfyui/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..8a33ef45ca8314d283436526caa4299d635d89e1 --- /dev/null +++ b/x-flux-comfyui/nodes.py @@ -0,0 +1,740 @@ +import os + +import comfy.model_management as mm +import comfy.model_patcher as mp +from comfy.utils import ProgressBar +from comfy.clip_vision import load as load_clip_vision +from comfy.clip_vision import clip_preprocess, Output +import latent_preview +import copy + +import folder_paths + +import torch +#from .xflux.src.flux.modules.layers import DoubleStreamBlockLoraProcessor, DoubleStreamBlockProcessor +#from .xflux.src.flux.model import Flux as ModFlux + +from .xflux.src.flux.util import (configs, load_ae, load_clip, + load_flow_model, load_t5, load_safetensors, load_from_repo_id, + load_controlnet) + + +from .utils import (FirstHalfStrengthModel, FluxUpdateModules, LinearStrengthModel, + SecondHalfStrengthModel, SigmoidStrengthModel, attn_processors, + set_attn_processor, + is_model_pathched, merge_loras, LATENT_PROCESSOR_COMFY, + comfy_to_xlabs_lora, check_is_comfy_lora) +from .layers import (DoubleStreamBlockLoraProcessor, + DoubleStreamBlockProcessor, + DoubleStreamBlockLorasMixerProcessor, + DoubleStreamMixerProcessor, + IPProcessor, + ImageProjModel) +from .xflux.src.flux.model import Flux as ModFlux +#from .model_init import double_blocks_init, single_blocks_init + + +from comfy.utils import get_attr, set_attr +from .clip import FluxClipViT + + +dir_xlabs = folder_paths.models_dir # os.path.join(folder_paths.models_dir, "xlabs") +#os.makedirs(dir_xlabs, exist_ok=True) +dir_xlabs_loras = os.path.join(dir_xlabs, "loras") +os.makedirs(dir_xlabs_loras, exist_ok=True) +dir_xlabs_controlnets = os.path.join(dir_xlabs, "controlnet") +os.makedirs(dir_xlabs_controlnets, exist_ok=True) +dir_xlabs_flux = os.path.join(dir_xlabs, "flux") +os.makedirs(dir_xlabs_flux, exist_ok=True) +dir_xlabs_ipadapters = os.path.join(dir_xlabs, "ipadapters") +os.makedirs(dir_xlabs_ipadapters, exist_ok=True) + + +folder_paths.folder_names_and_paths["xlabs"] = ([dir_xlabs], folder_paths.supported_pt_extensions) +folder_paths.folder_names_and_paths["xlabs_loras"] = ([dir_xlabs_loras], folder_paths.supported_pt_extensions) +folder_paths.folder_names_and_paths["xlabs_controlnets"] = ([dir_xlabs_controlnets], folder_paths.supported_pt_extensions) +folder_paths.folder_names_and_paths["xlabs_ipadapters"] = ([dir_xlabs_ipadapters], folder_paths.supported_pt_extensions) +folder_paths.folder_names_and_paths["xlabs_flux"] = ([dir_xlabs_flux], folder_paths.supported_pt_extensions) +folder_paths.folder_names_and_paths["xlabs_flux_json"] = ([dir_xlabs_flux], set({'.json',})) + + + +from .sampling import get_noise, prepare, get_schedule, denoise, denoise_controlnet, unpack +import numpy as np + +def load_flux_lora(path): + if path is not None: + if '.safetensors' in path: + checkpoint = load_safetensors(path) + else: + checkpoint = torch.load(path, map_location='cpu') + else: + checkpoint = None + print("Invalid path") + a1 = sorted(list(checkpoint[list(checkpoint.keys())[0]].shape))[0] + a2 = sorted(list(checkpoint[list(checkpoint.keys())[1]].shape))[0] + if a1==a2: + return checkpoint, int(a1) + return checkpoint, 16 + +def 
cleanprint(a): + pass#print(a) + return a + +def print_if_not_empty(a): + b = list(a.items()) + if len(b)<1: + return "{}" + return b[0] + +class LoadFluxLora: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "lora_name": (cleanprint(folder_paths.get_filename_list("xlabs_loras")), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + }} + + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("MODEL",) + FUNCTION = "loadmodel" + CATEGORY = "XLabsNodes" + + def loadmodel(self, model, lora_name, strength_model): + debug=False + + + device=mm.get_torch_device() + offload_device=mm.unet_offload_device() + + is_patched = is_model_pathched(model.model) + + print(f"Is model already patched? {is_patched}") + mul = 1 + if is_patched: + pbar = ProgressBar(5) + else: + mul = 3 + count = len(model.model.diffusion_model.double_blocks) + pbar = ProgressBar(5*mul+count) + + bi = model.clone() + tyanochky = bi.model + + if debug: + print("\n", (print_if_not_empty(bi.object_patches_backup)), "\n___\n", (print_if_not_empty(bi.object_patches)), "\n") + try: + print(get_attr(tyanochky, "diffusion_model.double_blocks.0.processor.lora_weight")) + except: + pass + + pbar.update(mul) + bi.model.to(device) + checkpoint, lora_rank = load_flux_lora(os.path.join(dir_xlabs_loras, lora_name)) + pbar.update(mul) + if not is_patched: + print("We are patching diffusion model, be patient please") + patches=FluxUpdateModules(tyanochky, pbar) + #set_attn_processor(model.model.diffusion_model, DoubleStreamBlockProcessor()) + else: + print("Model already updated") + pbar.update(mul) + #TYANOCHKYBY=16 + + lora_attn_procs = {} + if checkpoint is not None: + if check_is_comfy_lora(checkpoint): + checkpoint = comfy_to_xlabs_lora(checkpoint) + #cached_proccesors = attn_processors(tyanochky.diffusion_model).items() + for name, _ in attn_processors(tyanochky.diffusion_model).items(): + lora_attn_procs[name] = DoubleStreamBlockLoraProcessor( + dim=3072, rank=lora_rank, lora_weight=strength_model) + lora_state_dict = {} + for k in checkpoint.keys(): + if name in k: + lora_state_dict[k[len(name) + 1:]] = checkpoint[k] + lora_attn_procs[name].load_state_dict(lora_state_dict) + lora_attn_procs[name].to(device) + tmp=DoubleStreamMixerProcessor() + tmp.add_lora(lora_attn_procs[name]) + lora_attn_procs[name]=tmp + pbar.update(mul) + #set_attn_processor(tyanochky.diffusion_model, lora_attn_procs) + if debug: + try: + if isinstance( + get_attr(tyanochky, "diffusion_model.double_blocks.0.processor"), + DoubleStreamMixerProcessor + ): + pedovki = get_attr(tyanochky, "diffusion_model.double_blocks.0.processor").lora_weight + if len(pedovki)>0: + altushki="".join([f"{pedov:.2f}, " for pedov in pedovki]) + print(f"Loras applied: {altushki}") + except: + pass + + for name, _ in attn_processors(tyanochky.diffusion_model).items(): + attribute = f"diffusion_model.{name}" + #old = copy.copy(get_attr(bi.model, attribute)) + if attribute in model.object_patches.keys(): + old = copy.copy((model.object_patches[attribute])) + else: + old = None + lora = merge_loras(old, lora_attn_procs[name]) + bi.add_object_patch(attribute, lora) + + + if debug: + print("\n", (print_if_not_empty(bi.object_patches_backup)), "\n_b_\n", (print_if_not_empty(bi.object_patches)), "\n") + print("\n", (print_if_not_empty(model.object_patches_backup)), "\n_m__\n", (print_if_not_empty(model.object_patches)), "\n") + + for _, b in bi.object_patches.items(): + print(b.lora_weight) + break + + 
#print(get_attr(tyanochky, "diffusion_model.double_blocks.0.processor")) + pbar.update(mul) + return (bi,) + +def load_checkpoint_controlnet(local_path): + if local_path is not None: + if '.safetensors' in local_path: + checkpoint = load_safetensors(local_path) + else: + checkpoint = torch.load(local_path, map_location='cpu') + else: + checkpoint=None + print("Invalid path") + return checkpoint + +class LoadFluxControlNet: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model_name": (["flux-dev", "flux-dev-fp8", "flux-schnell"],), + "controlnet_path": (folder_paths.get_filename_list("xlabs_controlnets"), ), + }} + + RETURN_TYPES = ("FluxControlNet",) + RETURN_NAMES = ("ControlNet",) + FUNCTION = "loadmodel" + CATEGORY = "XLabsNodes" + + def loadmodel(self, model_name, controlnet_path): + device=mm.get_torch_device() + + controlnet = load_controlnet(model_name, device) + checkpoint = load_checkpoint_controlnet(os.path.join(dir_xlabs_controlnets, controlnet_path)) + if checkpoint is not None: + controlnet.load_state_dict(checkpoint) + control_type = "canny" + ret_controlnet = { + "model": controlnet, + "control_type": control_type, + } + return (ret_controlnet,) + +class ApplyFluxControlNet: + @classmethod + def INPUT_TYPES(s): + return {"required": {"controlnet": ("FluxControlNet",), + "image": ("IMAGE", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + + RETURN_TYPES = ("ControlNetCondition",) + RETURN_NAMES = ("controlnet_condition",) + FUNCTION = "prepare" + CATEGORY = "XLabsNodes" + + def prepare(self, controlnet, image, strength): + device=mm.get_torch_device() + controlnet_image = torch.from_numpy((np.array(image) * 2) - 1) + controlnet_image = controlnet_image.permute(0, 3, 1, 2).to(torch.bfloat16).to(device) + + ret_cont = { + "img": controlnet_image, + "controlnet_strength": strength, + "model": controlnet["model"], + "start": 0.0, + "end": 1.0 + } + return (ret_cont,) + +class ApplyAdvancedFluxControlNet: + @classmethod + def INPUT_TYPES(s): + return {"required": {"controlnet": ("FluxControlNet",), + "image": ("IMAGE", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + }} + + RETURN_TYPES = ("ControlNetCondition",) + RETURN_NAMES = ("controlnet_condition",) + FUNCTION = "prepare" + CATEGORY = "XLabsNodes" + + def prepare(self, controlnet, image, strength, start, end): + device=mm.get_torch_device() + controlnet_image = torch.from_numpy((np.array(image) * 2) - 1) + controlnet_image = controlnet_image.permute(0, 3, 1, 2).to(torch.bfloat16).to(device) + + ret_cont = { + "img": controlnet_image, + "controlnet_strength": strength, + "model": controlnet["model"], + "start": start, + "end": end + } + return (ret_cont,) + +class XlabsSampler: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "conditioning": ("CONDITIONING",), + "neg_conditioning": ("CONDITIONING",), + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "steps": ("INT", {"default": 20, "min": 1, "max": 100}), + "timestep_to_start_cfg": ("INT", {"default": 20, "min": 0, "max": 100}), + "true_gs": ("FLOAT", {"default": 3, "min": 0, "max": 100}), + "image_to_image_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 
0.01}), + }, + "optional": { + "latent_image": ("LATENT", {"default": None}), + "controlnet_condition": ("ControlNetCondition", {"default": None}), + } + } + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("latent",) + FUNCTION = "sampling" + CATEGORY = "XLabsNodes" + + def sampling(self, model, conditioning, neg_conditioning, + noise_seed, steps, timestep_to_start_cfg, true_gs, + image_to_image_strength, denoise_strength, + latent_image=None, controlnet_condition=None + ): + additional_steps = 11 if controlnet_condition is None else 12 + mm.load_model_gpu(model) + inmodel = model.model + #print(conditioning[0][0].shape) #//t5 + #print(conditioning[0][1]['pooled_output'].shape) #//clip + #print(latent_image['samples'].shape) #// torch.Size([1, 4, 64, 64]) // bc, 4, w//8, h//8 + try: + guidance = conditioning[0][1]['guidance'] + except: + guidance = 1.0 + + device=mm.get_torch_device() + if torch.backends.mps.is_available(): + device = torch.device("mps") + if torch.cuda.is_bf16_supported(): + dtype_model = torch.bfloat16 + else: + dtype_model = torch.float16 + #dtype_model = torch.bfloat16#model.model.diffusion_model.img_in.weight.dtype + offload_device=mm.unet_offload_device() + + torch.manual_seed(noise_seed) + + bc, c, h, w = latent_image['samples'].shape + height = (h//2) * 16 + width = (w//2) * 16 + + x = get_noise( + bc, height, width, device=device, + dtype=dtype_model, seed=noise_seed + ) + orig_x = None + if c==16: + orig_x=latent_image['samples'] + lat_processor2 = LATENT_PROCESSOR_COMFY() + orig_x=lat_processor2.go_back(orig_x) + orig_x=orig_x.to(device, dtype=dtype_model) + + + timesteps = get_schedule( + steps, + (width // 8) * (height // 8) // 4, + shift=True, + ) + try: + inmodel.to(device) + except: + pass + x.to(device) + + inmodel.diffusion_model.to(device) + inp_cond = prepare(conditioning[0][0], conditioning[0][1]['pooled_output'], img=x) + neg_inp_cond = prepare(neg_conditioning[0][0], neg_conditioning[0][1]['pooled_output'], img=x) + + if denoise_strength<=0.99: + try: + timesteps=timesteps[:int(len(timesteps)*denoise_strength)] + except: + pass + # for sampler preview + x0_output = {} + callback = latent_preview.prepare_callback(model, len(timesteps) - 1, x0_output) + + if controlnet_condition is None: + x = denoise( + inmodel.diffusion_model, **inp_cond, timesteps=timesteps, guidance=guidance, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + image2image_strength=image_to_image_strength, + orig_image=orig_x, + callback=callback, + width=width, + height=height, + ) + + else: + + controlnet = controlnet_condition['model'] + controlnet_image = controlnet_condition['img'] + controlnet_image = torch.nn.functional.interpolate( + controlnet_image, size=(height, width), scale_factor=None, mode='bicubic',) + controlnet_strength = controlnet_condition['controlnet_strength'] + controlnet_start = controlnet_condition['start'] + controlnet_end = controlnet_condition['end'] + controlnet.to(device, dtype=dtype_model) + controlnet_image=controlnet_image.to(device, dtype=dtype_model) + mm.load_models_gpu([model,]) + #mm.load_model_gpu(controlnet) + + total_steps = len(timesteps) + start_step = int(controlnet_start * total_steps) + end_step = int(controlnet_end * total_steps) + + x = denoise_controlnet( + inmodel.diffusion_model, **inp_cond, controlnet=controlnet, + timesteps=timesteps, guidance=guidance, + controlnet_cond=controlnet_image, + 
timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + controlnet_gs=controlnet_strength, + image2image_strength=image_to_image_strength, + orig_image=orig_x, + callback=callback, + width=width, + height=height, + controlnet_start_step=start_step, + controlnet_end_step=end_step + ) + #controlnet.to(offload_device) + + x = unpack(x, height, width) + lat_processor = LATENT_PROCESSOR_COMFY() + x = lat_processor(x) + lat_ret = {"samples": x} + + #model.model.to(offload_device) + return (lat_ret,) + + + +class LoadFluxIPAdapter: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "ipadatper": (folder_paths.get_filename_list("xlabs_ipadapters"),), + "clip_vision": (folder_paths.get_filename_list("clip_vision"),), + "provider": (["CPU", "GPU",],), + } + } + RETURN_TYPES = ("IP_ADAPTER_FLUX",) + RETURN_NAMES = ("ipadapterFlux",) + FUNCTION = "loadmodel" + CATEGORY = "XLabsNodes" + + def loadmodel(self, ipadatper, clip_vision, provider): + pbar = ProgressBar(6) + device=mm.get_torch_device() + offload_device=mm.unet_offload_device() + pbar.update(1) + ret_ipa = {} + path = os.path.join(dir_xlabs_ipadapters, ipadatper) + ckpt = load_safetensors(path) + pbar.update(1) + path_clip = folder_paths.get_full_path("clip_vision", clip_vision) + + try: + clip = FluxClipViT(path_clip) + except: + clip = load_clip_vision(path_clip).model + + ret_ipa["clip_vision"] = clip + prefix = "double_blocks." + blocks = {} + proj = {} + for key, value in ckpt.items(): + if key.startswith(prefix): + blocks[key[len(prefix):].replace('.processor.', '.')] = value + if key.startswith("ip_adapter_proj_model"): + proj[key[len("ip_adapter_proj_model."):]] = value + pbar.update(1) + improj = ImageProjModel(4096, 768, 4) + improj.load_state_dict(proj) + pbar.update(1) + ret_ipa["ip_adapter_proj_model"] = improj + + ret_ipa["double_blocks"] = torch.nn.ModuleList([IPProcessor(4096, 3072) for i in range(19)]) + ret_ipa["double_blocks"].load_state_dict(blocks) + #print("\n"*3) + #print(blocks.keys()) + #print("\n"*3) + #print(next(ret_ipa["double_blocks"].parameters())) + #print("\n"*3) + pbar.update(1) + return (ret_ipa,) + + + +class ApplyFluxIPAdapter: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "ip_adapter_flux": ("IP_ADAPTER_FLUX",), + "image": ("IMAGE",), + "strength_model": ("FLOAT", {"default": 0.6, "min": -100.0, "max": 100.0, "step": 0.01}), + }} + + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("MODEL",) + FUNCTION = "applymodel" + CATEGORY = "XLabsNodes" + + def applymodel(self, model, ip_adapter_flux, image, strength_model): + debug=False + + + device=mm.get_torch_device() + offload_device=mm.unet_offload_device() + + is_patched = is_model_pathched(model.model) + + print(f"Is model already patched? 
{is_patched}") + mul = 1 + if is_patched: + pbar = ProgressBar(5) + else: + mul = 3 + count = len(model.model.diffusion_model.double_blocks) + pbar = ProgressBar(5*mul+count) + + bi = model.clone() + tyanochky = bi.model + + clip = ip_adapter_flux['clip_vision'] + + if isinstance(clip, FluxClipViT): + #torch.Size([1, 526, 526, 3]) + #image = torch.permute(image, (0, )) + #print(image.shape) + #print(image) + clip_device = next(clip.model.parameters()).device + image = torch.clip(image*255, 0.0, 255) + out = clip(image).to(dtype=torch.bfloat16) + neg_out = clip(torch.zeros_like(image)).to(dtype=torch.bfloat16) + else: + print("Using old vit clip") + clip_device = next(clip.parameters()).device + pixel_values = clip_preprocess(image.to(clip_device)).float() + out = clip(pixel_values=pixel_values) + neg_out = clip(pixel_values=torch.zeros_like(pixel_values)) + neg_out = neg_out[2].to(dtype=torch.bfloat16) + out = out[2].to(dtype=torch.bfloat16) + + pbar.update(mul) + if not is_patched: + print("We are patching diffusion model, be patient please") + patches=FluxUpdateModules(tyanochky, pbar) + print("Patched succesfully!") + else: + print("Model already updated") + pbar.update(mul) + + #TYANOCHKYBY=16 + ip_projes_dev = next(ip_adapter_flux['ip_adapter_proj_model'].parameters()).device + ip_adapter_flux['ip_adapter_proj_model'].to(dtype=torch.bfloat16) + ip_projes = ip_adapter_flux['ip_adapter_proj_model'](out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16) + ip_neg_pr = ip_adapter_flux['ip_adapter_proj_model'](neg_out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16) + + + ipad_blocks = [] + for block in ip_adapter_flux['double_blocks']: + ipad = IPProcessor(block.context_dim, block.hidden_dim, ip_projes, strength_model) + ipad.load_state_dict(block.state_dict()) + ipad.in_hidden_states_neg = ip_neg_pr + ipad.in_hidden_states_pos = ip_projes + ipad.to(dtype=torch.bfloat16) + npp = DoubleStreamMixerProcessor() + npp.add_ipadapter(ipad) + ipad_blocks.append(npp) + pbar.update(mul) + i=0 + for name, _ in attn_processors(tyanochky.diffusion_model).items(): + attribute = f"diffusion_model.{name}" + #old = copy.copy(get_attr(bi.model, attribute)) + if attribute in model.object_patches.keys(): + old = copy.copy((model.object_patches[attribute])) + else: + old = None + processor = merge_loras(old, ipad_blocks[i]) + processor.to(device, dtype=torch.bfloat16) + bi.add_object_patch(attribute, processor) + i+=1 + pbar.update(mul) + return (bi,) + + + +class ApplyAdvancedFluxIPAdapter: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "ip_adapter_flux": ("IP_ADAPTER_FLUX",), + "image": ("IMAGE",), + "begin_strength": ("FLOAT", {"default": 0.0, "min": -100.0, "max": 100.0, "step": 0.01}), + "end_strength": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + "smothing_type": (["Linear", "First half", "Second half", "Sigmoid"],), + }} + + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("MODEL",) + FUNCTION = "applymodel" + CATEGORY = "XLabsNodes" + + def applymodel(self, model, ip_adapter_flux, image, begin_strength, end_strength, smothing_type): + debug=False + + + device=mm.get_torch_device() + offload_device=mm.unet_offload_device() + + is_patched = is_model_pathched(model.model) + + print(f"Is model already patched? 
{is_patched}") + mul = 1 + if is_patched: + pbar = ProgressBar(5) + else: + mul = 3 + count = len(model.model.diffusion_model.double_blocks) + pbar = ProgressBar(5*mul+count) + + bi = model.clone() + tyanochky = bi.model + + clip = ip_adapter_flux['clip_vision'] + + if isinstance(clip, FluxClipViT): + #torch.Size([1, 526, 526, 3]) + #image = torch.permute(image, (0, )) + #print(image.shape) + #print(image) + clip_device = next(clip.model.parameters()).device + image = torch.clip(image*255, 0.0, 255) + out = clip(image).to(dtype=torch.bfloat16) + neg_out = clip(torch.zeros_like(image)).to(dtype=torch.bfloat16) + else: + print("Using old vit clip") + clip_device = next(clip.parameters()).device + pixel_values = clip_preprocess(image.to(clip_device)).float() + out = clip(pixel_values=pixel_values) + neg_out = clip(pixel_values=torch.zeros_like(pixel_values)) + neg_out = neg_out[2].to(dtype=torch.bfloat16) + out = out[2].to(dtype=torch.bfloat16) + + pbar.update(mul) + if not is_patched: + print("We are patching diffusion model, be patient please") + patches=FluxUpdateModules(tyanochky, pbar) + print("Patched succesfully!") + else: + print("Model already updated") + pbar.update(mul) + + #TYANOCHKYBY=16 + ip_projes_dev = next(ip_adapter_flux['ip_adapter_proj_model'].parameters()).device + ip_adapter_flux['ip_adapter_proj_model'].to(dtype=torch.bfloat16) + out=torch.mean(out, 0) + neg_out=torch.mean(neg_out, 0) + ip_projes = ip_adapter_flux['ip_adapter_proj_model'](out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16) + ip_neg_pr = ip_adapter_flux['ip_adapter_proj_model'](neg_out.to(ip_projes_dev, dtype=torch.bfloat16)).to(device, dtype=torch.bfloat16) + + + count = len(ip_adapter_flux['double_blocks']) + + if smothing_type == "Linear": + strength_model = LinearStrengthModel(begin_strength, end_strength, count) + elif smothing_type == "First half": + strength_model = FirstHalfStrengthModel(begin_strength, end_strength, count) + elif smothing_type == "Second half": + strength_model = SecondHalfStrengthModel(begin_strength, end_strength, count) + elif smothing_type == "Sigmoid": + strength_model = SigmoidStrengthModel(begin_strength, end_strength, count) + else: + raise ValueError("Invalid smothing type") + + + ipad_blocks = [] + for i, block in enumerate(ip_adapter_flux['double_blocks']): + ipad = IPProcessor(block.context_dim, block.hidden_dim, ip_projes, strength_model[i]) + ipad.load_state_dict(block.state_dict()) + ipad.in_hidden_states_neg = ip_neg_pr + ipad.in_hidden_states_pos = ip_projes + ipad.to(dtype=torch.bfloat16) + npp = DoubleStreamMixerProcessor() + npp.add_ipadapter(ipad) + ipad_blocks.append(npp) + pbar.update(mul) + i=0 + for name, _ in attn_processors(tyanochky.diffusion_model).items(): + attribute = f"diffusion_model.{name}" + #old = copy.copy(get_attr(bi.model, attribute)) + if attribute in model.object_patches.keys(): + old = copy.copy((model.object_patches[attribute])) + else: + old = None + processor = merge_loras(old, ipad_blocks[i]) + processor.to(device, dtype=torch.bfloat16) + bi.add_object_patch(attribute, processor) + i+=1 + pbar.update(mul) + return (bi,) + + + +NODE_CLASS_MAPPINGS = { + "FluxLoraLoader": LoadFluxLora, + "LoadFluxControlNet": LoadFluxControlNet, + "ApplyFluxControlNet": ApplyFluxControlNet, + "ApplyAdvancedFluxControlNet": ApplyAdvancedFluxControlNet, + "XlabsSampler": XlabsSampler, + "ApplyFluxIPAdapter": ApplyFluxIPAdapter, + "LoadFluxIPAdapter": LoadFluxIPAdapter, + "ApplyAdvancedFluxIPAdapter": ApplyAdvancedFluxIPAdapter, 
+} +NODE_DISPLAY_NAME_MAPPINGS = { + "FluxLoraLoader": "Load Flux LoRA", + "LoadFluxControlNet": "Load Flux ControlNet", + "ApplyFluxControlNet": "Apply Flux ControlNet", + "ApplyAdvancedFluxControlNet": "Apply Advanced Flux ControlNet", + "XlabsSampler": "Xlabs Sampler", + "ApplyFluxIPAdapter": "Apply Flux IPAdapter", + "LoadFluxIPAdapter": "Load Flux IPAdapter", + "ApplyAdvancedFluxIPAdapter": "Apply Advanced Flux IPAdapter", +} diff --git a/x-flux-comfyui/pyproject.toml b/x-flux-comfyui/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..6b7095438c26bb1f7f2e54c8231d6956eeb1fe9c --- /dev/null +++ b/x-flux-comfyui/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "x-flux-comfyui" +description = "Nodes: Load Flux LoRA, Load Flux ControlNet, Apply Flux ControlNet, Xlabs Sampler" +version = "1.0.0" +license = {file = "LICENSE"} +dependencies = ["GitPython", "einops==0.8.0", "transformers", "diffusers", "sentencepiece", "opencv-python"] + +[project.urls] +Repository = "https://github.com/XLabs-AI/x-flux-comfyui" +# Used by Comfy Registry https://comfyregistry.org + +[tool.comfy] +PublisherId = "" +DisplayName = "x-flux-comfyui" +Icon = "" diff --git a/x-flux-comfyui/requirements.txt b/x-flux-comfyui/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..de880acee3b1075fddb1e08ff5af62fe45f1b0d8 --- /dev/null +++ b/x-flux-comfyui/requirements.txt @@ -0,0 +1,6 @@ +GitPython +einops==0.8.0 +transformers +diffusers +sentencepiece +opencv-python diff --git a/x-flux-comfyui/sampling.py b/x-flux-comfyui/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2480bfb1d91fd90ec0b8fec49b4adf3a4aabc0 --- /dev/null +++ b/x-flux-comfyui/sampling.py @@ -0,0 +1,352 @@ +import math +from typing import Callable + +import torch +from einops import rearrange, repeat +from torch import Tensor +import numpy as np + +#from .modules.conditioner import HFEmbedder +from .layers import DoubleStreamMixerProcessor, timestep_embedding +from tqdm.auto import tqdm + +def model_forward( + model, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + block_controlnet_hidden_states=None, + guidance: Tensor | None = None, + neg_mode: bool | None = False, +) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + # running on sequences img + img = model.img_in(img) + vec = model.time_in(timestep_embedding(timesteps, 256)) + if model.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + model.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + model.vector_in(y) + txt = model.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = model.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + for index_block, block in enumerate(model.double_blocks): + if hasattr(block, "processor"): + if isinstance(block.processor, DoubleStreamMixerProcessor): + if neg_mode: + for ip in block.processor.ip_adapters: + ip.ip_hidden_states = ip.in_hidden_states_neg + else: + for ip in block.processor.ip_adapters: + ip.ip_hidden_states = ip.in_hidden_states_pos + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + # controlnet residual: cycle the available residual tensors across the double-stream blocks + + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[index_block % 2] + + + img = torch.cat((txt,
img), 1) + for block in model.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = model.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img + +def get_noise( + num_samples: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + seed: int, +): + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + +def prepare(txt_t5, vec_clip, img: Tensor) -> dict[str, Tensor]: + txt = txt_t5 + vec = vec_clip + bs, c, h, w = img.shape + + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... -> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device, dtype=img.dtype), + "txt": txt.to(img.device, dtype=img.dtype), + "txt_ids": txt_ids.to(img.device, dtype=img.dtype), + "vec": vec.to(img.device, dtype=img.dtype), + } + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + timestep_to_start_cfg=0, + image2image_strength=None, + orig_image = None, + callback = None, + width = 512, + height = 512, +): + i = 0 + + #init_latents = rearrange(init_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if image2image_strength is not None and orig_image is not None: + + t_idx = int((1 - np.clip(image2image_strength, 0.0, 1.0)) * len(timesteps)) + t = timesteps[t_idx] + timesteps = timesteps[t_idx:] + orig_image = rearrange(orig_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2).to(img.device, dtype = img.dtype) + img = t * img + (1.0 - t) * orig_image + img_ids=img_ids.to(img.device, dtype=img.dtype) + txt=txt.to(img.device, dtype=img.dtype) + txt_ids=txt_ids.to(img.device, dtype=img.dtype) + vec=vec.to(img.device, dtype=img.dtype) + if hasattr(model, "guidance_in"): + guidance_vec = torch.full((img.shape[0],), guidance, 
device=img.device, dtype=img.dtype) + else: + # this is ignored for schnell + guidance_vec = None + for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), desc="Sampling", total = len(timesteps)-1): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model_forward( + model, + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + if i >= timestep_to_start_cfg: + neg_pred = model_forward( + model, + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + neg_mode = True, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + img = img + (t_prev - t_curr) * pred + + if callback is not None: + unpacked = unpack(img.float(), height, width) + callback(step=i, x=img, x0=unpacked, total_steps=len(timesteps) - 1) + i += 1 + + return img + +def denoise_controlnet( + model, + controlnet: None, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + controlnet_cond, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + controlnet_gs=0.7, + timestep_to_start_cfg=0, + image2image_strength=None, + orig_image = None, + callback = None, + width = 512, + height = 512, + controlnet_start_step=0, + controlnet_end_step=None +): + i = 0 + + if image2image_strength is not None and orig_image is not None: + t_idx = int((1 - np.clip(image2image_strength, 0.0, 1.0)) * len(timesteps)) + t = timesteps[t_idx] + timesteps = timesteps[t_idx:] + orig_image = rearrange(orig_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2).to(img.device, dtype = img.dtype) + img = t * img + (1.0 - t) * orig_image + + img_ids = img_ids.to(img.device, dtype=img.dtype) + txt = txt.to(img.device, dtype=img.dtype) + txt_ids = txt_ids.to(img.device, dtype=img.dtype) + vec = vec.to(img.device, dtype=img.dtype) + controlnet.to(img.device, dtype=img.dtype) + controlnet_cond = controlnet_cond.to(img.device, dtype=img.dtype) + + if hasattr(model, "guidance_in"): + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + else: + guidance_vec = None + + if controlnet_end_step is None: + controlnet_end_step = len(timesteps) - 1 + + for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), desc="Sampling", total=len(timesteps)-1): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + guidance_vec = guidance_vec.to(img.device, dtype=img.dtype) + + if controlnet_start_step <= i <= controlnet_end_step: + block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + controlnet_hidden_states = [sample * controlnet_gs for sample in block_res_samples] + else: + controlnet_hidden_states = None + + pred = model_forward( + model, + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=controlnet_hidden_states + ) + + if i >= timestep_to_start_cfg: + if controlnet_start_step <= i <= controlnet_end_step: + neg_block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + neg_controlnet_hidden_states = [sample * controlnet_gs 
for sample in neg_block_res_samples] + else: + neg_controlnet_hidden_states = None + + neg_pred = model_forward( + model, + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=neg_controlnet_hidden_states, + neg_mode=True, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + img = img + (t_prev - t_curr) * pred + + if callback is not None: + unpacked = unpack(img.float(), height, width) + callback(step=i, x=img, x0=unpacked, total_steps=len(timesteps) - 1) + i += 1 + return img + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/x-flux-comfyui/setup-py.bat b/x-flux-comfyui/setup-py.bat new file mode 100644 index 0000000000000000000000000000000000000000..d0e0fc80af44d56237a5cce04e9213cacb34b4e1 --- /dev/null +++ b/x-flux-comfyui/setup-py.bat @@ -0,0 +1,15 @@ +@echo off + +set "python_exec=..\..\..\python_embedded\python.exe" + +echo Installing node... + +if exist "%python_exec%" ( + echo Installing with ComfyUI Portable + "%python_exec%" setup.py +) else ( + echo Installing with system Python + python setup.py +) + +pause diff --git a/x-flux-comfyui/setup.py b/x-flux-comfyui/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..bdf32f0f72b9ecf9c6a45d57016a9c89c2499e27 --- /dev/null +++ b/x-flux-comfyui/setup.py @@ -0,0 +1,20 @@ +import os +if False: + try: + import git + git.Git(".").clone("https://github.com/XLabs-AI/x-flux") + except Exception: + + os.system("git clone https://github.com/XLabs-AI/x-flux" ) +#os.rename("x-flux", "xflux") +cur_dir = os.path.dirname(os.path.abspath(__file__)) +if False: + run = f'mv x-flux "{cur_dir}/xflux"' + if os.name == 'nt': + run = f'move x-flux "{cur_dir}\\xflux"' + os.system(run) +if os.name == 'nt': + os.system(f'pip install -r "{cur_dir}\\requirements.txt"') +else: + os.system(f'pip install -r "{cur_dir}/requirements.txt"') +print("Successfully installed") diff --git a/x-flux-comfyui/utils.py b/x-flux-comfyui/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb0c58f3f7d1fb5d3aa3e0c80d9f6a2617720c0 --- /dev/null +++ b/x-flux-comfyui/utils.py @@ -0,0 +1,264 @@ +from comfy.ldm.flux.layers import DoubleStreamBlock as DSBold +import copy +import torch +from .xflux.src.flux.modules.layers import DoubleStreamBlock as DSBnew +from .layers import (DoubleStreamBlockLoraProcessor, + DoubleStreamBlockProcessor, + DoubleStreamBlockLorasMixerProcessor, + DoubleStreamMixerProcessor) + +from comfy.utils import get_attr, set_attr + +import numpy as np + +def CopyDSB(oldDSB): + + if isinstance(oldDSB, DSBold): + tyan = copy.copy(oldDSB) + + if hasattr(tyan.img_mlp[0], 'out_features'): + mlp_hidden_dim = tyan.img_mlp[0].out_features + else: + mlp_hidden_dim = 12288 + + mlp_ratio = mlp_hidden_dim / tyan.hidden_size + bi = DSBnew(hidden_size=tyan.hidden_size, num_heads=tyan.num_heads, mlp_ratio=mlp_ratio) + # it might be better to copy __dict__, but explicit attribute copying is safer here + ( + bi.img_mod, bi.img_norm1, bi.img_attn, bi.img_norm2, + bi.img_mlp, bi.txt_mod, bi.txt_norm1, bi.txt_attn, bi.txt_norm2, bi.txt_mlp + ) = ( + tyan.img_mod, tyan.img_norm1, tyan.img_attn, tyan.img_norm2, + tyan.img_mlp, tyan.txt_mod, tyan.txt_norm1, tyan.txt_attn, tyan.txt_norm2, tyan.txt_mlp + ) + bi.set_processor(DoubleStreamBlockProcessor()) + + return bi + return oldDSB + +def copy_model(orig, new): + new =
copy.copy(new) + new.model = copy.copy(orig.model) + new.model.diffusion_model = copy.copy(orig.model.diffusion_model) + new.model.diffusion_model.double_blocks = copy.deepcopy(orig.model.diffusion_model.double_blocks) + count = len(new.model.diffusion_model.double_blocks) + for i in range(count): + new.model.diffusion_model.double_blocks[i] = copy.copy(orig.model.diffusion_model.double_blocks[i]) + new.model.diffusion_model.double_blocks[i].load_state_dict(orig.model.diffusion_model.double_blocks[0].state_dict()) +""" +class PbarWrapper: + def __init__(self): + self.count = 1 + self.weights = [] + self.counts = [] + self.w8ts = [] + self.rn = 0 + self.rnf = 0.0 + def add(self, count, weight): + self.weights.append(weight) + self.counts.append(count) + wa = np.array(self.weights) + wa = wa/np.sum(wa) + ca = np.array(self.counts) + ml = np.multiply(ca, wa) + cas = np.sum(ml) + self.count=int(cas) + self.w8ts = wa.tolist() + def start(self): + self.rnf = 0.0 + self.rn = 0 + def __call__(self): + self.rn+=1 + return 1 +""" +def FluxUpdateModules(flux_model, pbar=None): + save_list = {} + #print((flux_model.diffusion_model.double_blocks)) + #for k,v in flux_model.diffusion_model.double_blocks: + #if "double" in k: + count = len(flux_model.diffusion_model.double_blocks) + patches = {} + + for i in range(count): + if pbar is not None: + pbar.update(1) + patches[f"double_blocks.{i}"]=CopyDSB(flux_model.diffusion_model.double_blocks[i]) + flux_model.diffusion_model.double_blocks[i]=CopyDSB(flux_model.diffusion_model.double_blocks[i]) + return patches + +def is_model_pathched(model): + def test(mod): + if isinstance(mod, DSBnew): + return True + else: + for p in mod.children(): + if test(p): + return True + return False + result = test(model) + return result + + + +def attn_processors(model_flux): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, procs): + + if hasattr(module, "set_processor"): + procs[f"{name}.processor"] = module.processor + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, procs) + + return procs + + for name, module in model_flux.named_children(): + fn_recursive_add_processors(name, module, processors) + return processors +def merge_loras(lora1, lora2): + new_block = DoubleStreamMixerProcessor() + if isinstance(lora1, DoubleStreamMixerProcessor): + new_block.set_loras(*lora1.get_loras()) + new_block.set_ip_adapters(lora1.get_ip_adapters()) + elif isinstance(lora1, DoubleStreamBlockLoraProcessor): + new_block.add_lora(lora1) + else: + pass + if isinstance(lora2, DoubleStreamMixerProcessor): + new_block.set_loras(*lora2.get_loras()) + new_block.set_ip_adapters(lora2.get_ip_adapters()) + elif isinstance(lora2, DoubleStreamBlockLoraProcessor): + new_block.add_lora(lora2) + else: + pass + return new_block + +def set_attn_processor(model_flux, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
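+ + Example (an illustrative sketch, not from the original docs; "flux_model" is assumed to be a patched Flux diffusion model whose double-stream blocks expose `set_processor`, and the helpers below are the ones defined in this module): + + procs = attn_processors(flux_model) # e.g. {"double_blocks.0.processor": ..., ...} + set_attn_processor(flux_model, {name: DoubleStreamBlockProcessor() for name in procs}) +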
+ + """ + count = len(attn_processors(model_flux).keys()) + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if isinstance(module.get_processor(), DoubleStreamBlockLorasMixerProcessor): + block = copy.copy(module.get_processor()) + module.set_processor(copy.deepcopy(module.get_processor())) + new_block = DoubleStreamBlockLorasMixerProcessor() + #q1, q2, p1, p2, w1 = block.get_loras() + new_block.set_loras(*block.get_loras()) + if not isinstance(processor, dict): + new_block.add_lora(processor) + else: + + new_block.add_lora(processor.pop(f"{name}.processor")) + module.set_processor(new_block) + #block = set_attr(module, "", new_block) + elif isinstance(module.get_processor(), DoubleStreamBlockLoraProcessor): + block = DoubleStreamBlockLorasMixerProcessor() + block.add_lora(copy.copy(module.get_processor())) + if not isinstance(processor, dict): + block.add_lora(processor) + else: + block.add_lora(processor.pop(f"{name}.processor")) + module.set_processor(block) + else: + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in model_flux.named_children(): + fn_recursive_attn_processor(name, module, processor) + +class LATENT_PROCESSOR_COMFY: + def __init__(self): + self.scale_factor = 0.3611 + self.shift_factor = 0.1159 + self.latent_rgb_factors =[ + [-0.0404, 0.0159, 0.0609], + [ 0.0043, 0.0298, 0.0850], + [ 0.0328, -0.0749, -0.0503], + [-0.0245, 0.0085, 0.0549], + [ 0.0966, 0.0894, 0.0530], + [ 0.0035, 0.0399, 0.0123], + [ 0.0583, 0.1184, 0.1262], + [-0.0191, -0.0206, -0.0306], + [-0.0324, 0.0055, 0.1001], + [ 0.0955, 0.0659, -0.0545], + [-0.0504, 0.0231, -0.0013], + [ 0.0500, -0.0008, -0.0088], + [ 0.0982, 0.0941, 0.0976], + [-0.1233, -0.0280, -0.0897], + [-0.0005, -0.0530, -0.0020], + [-0.1273, -0.0932, -0.0680] + ] + def __call__(self, x): + return (x / self.scale_factor) + self.shift_factor + def go_back(self, x): + return (x - self.shift_factor) * self.scale_factor + + + +def check_is_comfy_lora(sd): + for k in sd: + if "lora_down" in k or "lora_up" in k: + return True + return False + +def comfy_to_xlabs_lora(sd): + sd_out = {} + for k in sd: + if "diffusion_model" in k: + new_k = (k + .replace(".lora_down.weight", ".down.weight") + .replace(".lora_up.weight", ".up.weight") + .replace(".img_attn.proj.", ".processor.proj_lora1.") + .replace(".txt_attn.proj.", ".processor.proj_lora2.") + .replace(".img_attn.qkv.", ".processor.qkv_lora1.") + .replace(".txt_attn.qkv.", ".processor.qkv_lora2.")) + new_k = new_k[len("diffusion_model."):] + else: + new_k=k + sd_out[new_k] = sd[k] + return sd_out + +def LinearStrengthModel(start, finish, size): + return [ + (start + (finish - start) * (i / (size - 1))) for i in range(size) + ] +def FirstHalfStrengthModel(start, finish, size): + sizehalf = size//2 + arr = [ + (start + (finish - start) * (i / (sizehalf - 1))) for i in range(sizehalf) + ] + return arr+[finish]*(size-sizehalf) +def SecondHalfStrengthModel(start, finish, size): + sizehalf = size//2 + arr = [ + (start + 
(finish - start) * (i / (sizehalf - 1))) for i in range(sizehalf) + ] + return [start]*(size-sizehalf)+arr +def SigmoidStrengthModel(start, finish, size): + def fade_out(x, x1, x2): + return 1 / (1 + np.exp(-(x - (x1 + x2) / 2) * 8 / (x2 - x1))) + arr = [start + (finish - start) * (fade_out(i, 0, size) - 0.5) for i in range(size)] + return arr diff --git a/x-flux-comfyui/workflows/canny_workflow.json b/x-flux-comfyui/workflows/canny_workflow.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6e9a2b7ae80fa548f60d7ba99b2c595ab450ab --- /dev/null +++ b/x-flux-comfyui/workflows/canny_workflow.json @@ -0,0 +1,666 @@ +{ + "last_node_id": 22, + "last_link_id": 35, + "nodes": [ + { + "id": 8, + "type": "VAELoader", + "pos": [ + 1102, + 48 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 7 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "ae.safetensors" + ] + }, + { + "id": 4, + "type": "DualCLIPLoader", + "pos": [ + -157, + 198 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 2, + 27 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "DualCLIPLoader" + }, + "widgets_values": [ + "clip_l.safetensors", + "t5xxl_fp16.safetensors", + "flux" + ] + }, + { + "id": 10, + "type": "UNETLoader", + "pos": [ + 243, + 590 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 35 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "flux1-dev-fp8.safetensors", + "fp8_e4m3fn" + ] + }, + { + "id": 19, + "type": "CLIPTextEncodeFlux", + "pos": [ + 65, + 281 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 27, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 26 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "", + "", + 4 + ] + }, + { + "id": 13, + "type": "LoadFluxControlNet", + "pos": [ + 4, + -226 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "ControlNet", + "type": "FluxControlNet", + "links": [ + 19 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "LoadFluxControlNet" + }, + "widgets_values": [ + "flux-dev", + "flux-canny-controlnet.safetensors" + ] + }, + { + "id": 7, + "type": "VAEDecode", + "pos": [ + 1371, + 152 + ], + "size": { + "0": 210, + "1": 46 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 6, + "slot_index": 0 + }, + { + "name": "vae", + "type": "VAE", + "link": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 31 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 3, + "type": "XlabsSampler", + "pos": [ + 1013, + 169 + ], + "size": { + "0": 342.5999755859375, + "1": 234 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "model", + 
"type": "MODEL", + "link": 35, + "slot_index": 0 + }, + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 18 + }, + { + "name": "neg_conditioning", + "type": "CONDITIONING", + "link": 26 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 5 + }, + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "link": 28 + } + ], + "outputs": [ + { + "name": "latent", + "type": "LATENT", + "links": [ + 6 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "XlabsSampler" + }, + "widgets_values": [ + 257762932021983, + "fixed", + 25, + 1, + 3.5 + ] + }, + { + "id": 15, + "type": "CannyEdgePreprocessor", + "pos": [ + -26, + -74 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 22 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 21, + 32 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CannyEdgePreprocessor" + }, + "widgets_values": [ + 100, + 200, + 832 + ] + }, + { + "id": 14, + "type": "ApplyFluxControlNet", + "pos": [ + 546, + -262 + ], + "size": { + "0": 393, + "1": 78 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "controlnet", + "type": "FluxControlNet", + "link": 19 + }, + { + "name": "image", + "type": "IMAGE", + "link": 21, + "slot_index": 1 + } + ], + "outputs": [ + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "links": [ + 28 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyFluxControlNet" + }, + "widgets_values": [ + 0.8 + ] + }, + { + "id": 21, + "type": "PreviewImage", + "pos": [ + 1534, + 69 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 31, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 6, + "type": "EmptyLatentImage", + "pos": [ + 769, + 430 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 5 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 768, + 768, + 1 + ] + }, + { + "id": 5, + "type": "CLIPTextEncodeFlux", + "pos": [ + 518, + -63 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 2, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 18 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "A charismatic speaker is captured mid-speech. He has long, slightly wavy blonde hair tied back in a ponytail. His expressive face, adorned with a salt-and-pepper beard and mustache, is animated as he gestures with his left hand, displaying a large ring on his pinky finger. He is holding a black microphone in his right hand, speaking passionately. The man is wearing a dark, textured shirt with unique, slightly shimmering patterns, and a green lanyard with multiple badges and logos hanging around his neck. The lanyard features the \"Autodesk\" and \"V-", + "A charismatic speaker is captured mid-speech. He has long, slightly wavy blonde hair tied back in a ponytail. 
His expressive face, adorned with a salt-and-pepper beard and mustache, is animated as he gestures with his left hand, displaying a large ring on his pinky finger. He is holding a black microphone in his right hand, speaking passionately. The man is wearing a dark, textured shirt with unique, slightly shimmering patterns, and a green lanyard with multiple badges and logos hanging around his neck. The lanyard features the \"Autodesk\" and \"V-", + 4 + ] + }, + { + "id": 17, + "type": "PreviewImage", + "pos": [ + 281, + 8 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 32, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 16, + "type": "LoadImage", + "pos": [ + -378, + -239 + ], + "size": { + "0": 315, + "1": 314 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 22 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "example.jpg", + "image" + ] + } + ], + "links": [ + [ + 2, + 4, + 0, + 5, + 0, + "CLIP" + ], + [ + 5, + 6, + 0, + 3, + 3, + "LATENT" + ], + [ + 6, + 3, + 0, + 7, + 0, + "LATENT" + ], + [ + 7, + 8, + 0, + 7, + 1, + "VAE" + ], + [ + 18, + 5, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 19, + 13, + 0, + 14, + 0, + "FluxControlNet" + ], + [ + 21, + 15, + 0, + 14, + 1, + "IMAGE" + ], + [ + 22, + 16, + 0, + 15, + 0, + "IMAGE" + ], + [ + 26, + 19, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 27, + 4, + 0, + 19, + 0, + "CLIP" + ], + [ + 28, + 14, + 0, + 3, + 4, + "ControlNetCondition" + ], + [ + 31, + 7, + 0, + 21, + 0, + "IMAGE" + ], + [ + 32, + 15, + 0, + 17, + 0, + "IMAGE" + ], + [ + 35, + 10, + 0, + 3, + 0, + "MODEL" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 2.176291357901496, + "offset": [ + -1061.2497588685817, + 110.69101119830194 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/x-flux-comfyui/workflows/depth_workflow.json b/x-flux-comfyui/workflows/depth_workflow.json new file mode 100644 index 0000000000000000000000000000000000000000..ccd9aff1e7f05ad6d3070f38d2d688e6e0b9967c --- /dev/null +++ b/x-flux-comfyui/workflows/depth_workflow.json @@ -0,0 +1,666 @@ +{ + "last_node_id": 23, + "last_link_id": 40, + "nodes": [ + { + "id": 4, + "type": "DualCLIPLoader", + "pos": [ + -157, + 198 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 2, + 27 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "DualCLIPLoader" + }, + "widgets_values": [ + "clip_l.safetensors", + "t5xxl_fp16.safetensors", + "flux" + ] + }, + { + "id": 10, + "type": "UNETLoader", + "pos": [ + 243, + 590 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 35 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "flux1-dev-fp8.safetensors", + "fp8_e4m3fn" + ] + }, + { + "id": 19, + "type": "CLIPTextEncodeFlux", + "pos": [ + 65, + 281 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + 
"link": 27, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 26 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "", + "", + 4 + ] + }, + { + "id": 7, + "type": "VAEDecode", + "pos": [ + 1371, + 152 + ], + "size": { + "0": 210, + "1": 46 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 6, + "slot_index": 0 + }, + { + "name": "vae", + "type": "VAE", + "link": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 31 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 13, + "type": "LoadFluxControlNet", + "pos": [ + 4, + -226 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "ControlNet", + "type": "FluxControlNet", + "links": [ + 19 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "LoadFluxControlNet" + }, + "widgets_values": [ + "flux-dev", + "flux-depth-controlnet.safetensors" + ] + }, + { + "id": 8, + "type": "VAELoader", + "pos": [ + 1130, + 0 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 7 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "ae.safetensors" + ] + }, + { + "id": 6, + "type": "EmptyLatentImage", + "pos": [ + 769, + 430 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 5 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 768, + 768, + 1 + ] + }, + { + "id": 3, + "type": "XlabsSampler", + "pos": [ + 1007, + 170 + ], + "size": { + "0": 342.5999755859375, + "1": 234 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 35, + "slot_index": 0 + }, + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 18 + }, + { + "name": "neg_conditioning", + "type": "CONDITIONING", + "link": 26 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 5 + }, + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "link": 28 + } + ], + "outputs": [ + { + "name": "latent", + "type": "LATENT", + "links": [ + 6 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "XlabsSampler" + }, + "widgets_values": [ + 257762932021983, + "fixed", + 25, + 1, + 3.5 + ] + }, + { + "id": 17, + "type": "PreviewImage", + "pos": [ + 330, + 20 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 38, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 23, + "type": "MiDaS-DepthMapPreprocessor", + "pos": [ + -27, + -50 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 37 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 38, + 40 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "MiDaS-DepthMapPreprocessor" + 
}, + "widgets_values": [ + 6.283185307179586, + 0.1, + 768 + ] + }, + { + "id": 21, + "type": "PreviewImage", + "pos": [ + 1519, + 134 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 31, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 14, + "type": "ApplyFluxControlNet", + "pos": [ + 546, + -262 + ], + "size": { + "0": 393, + "1": 78 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "controlnet", + "type": "FluxControlNet", + "link": 19 + }, + { + "name": "image", + "type": "IMAGE", + "link": 40, + "slot_index": 1 + } + ], + "outputs": [ + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "links": [ + 28 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyFluxControlNet" + }, + "widgets_values": [ + 0.86 + ] + }, + { + "id": 5, + "type": "CLIPTextEncodeFlux", + "pos": [ + 518, + -63 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 2, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 18 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "man with microphone in the office, anime", + "man with microphone in the office, anime", + 4 + ] + }, + { + "id": 16, + "type": "LoadImage", + "pos": [ + -378, + -239 + ], + "size": { + "0": 315, + "1": 314 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 37 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "example.jpg", + "image" + ] + } + ], + "links": [ + [ + 2, + 4, + 0, + 5, + 0, + "CLIP" + ], + [ + 5, + 6, + 0, + 3, + 3, + "LATENT" + ], + [ + 6, + 3, + 0, + 7, + 0, + "LATENT" + ], + [ + 7, + 8, + 0, + 7, + 1, + "VAE" + ], + [ + 18, + 5, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 19, + 13, + 0, + 14, + 0, + "FluxControlNet" + ], + [ + 26, + 19, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 27, + 4, + 0, + 19, + 0, + "CLIP" + ], + [ + 28, + 14, + 0, + 3, + 4, + "ControlNetCondition" + ], + [ + 31, + 7, + 0, + 21, + 0, + "IMAGE" + ], + [ + 35, + 10, + 0, + 3, + 0, + "MODEL" + ], + [ + 37, + 16, + 0, + 23, + 0, + "IMAGE" + ], + [ + 38, + 23, + 0, + 17, + 0, + "IMAGE" + ], + [ + 40, + 23, + 0, + 14, + 1, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1.3513057093103997, + "offset": [ + -758.5037788677209, + 160.33608624136815 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/x-flux-comfyui/workflows/example.jpg b/x-flux-comfyui/workflows/example.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4eff2feb0207272dd3c043ac076adfe5df98b778 --- /dev/null +++ b/x-flux-comfyui/workflows/example.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79d18658d7caecb446ab595ea3142104d66fc727f1bb39e293bfb1193168777f +size 150944 diff --git a/x-flux-comfyui/workflows/flux-controlnet-canny-v3-workflow.json b/x-flux-comfyui/workflows/flux-controlnet-canny-v3-workflow.json new file mode 100644 index 
0000000000000000000000000000000000000000..9913c5d8e88f98854cbef77ed6474bcee1f73f89 --- /dev/null +++ b/x-flux-comfyui/workflows/flux-controlnet-canny-v3-workflow.json @@ -0,0 +1,738 @@ +{ + "last_node_id": 48, + "last_link_id": 115, + "nodes": [ + { + "id": 14, + "type": "ApplyFluxControlNet", + "pos": [ + 546, + -264 + ], + "size": { + "0": 393, + "1": 78 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "controlnet", + "type": "FluxControlNet", + "link": 44 + }, + { + "name": "image", + "type": "IMAGE", + "link": 114, + "slot_index": 1 + } + ], + "outputs": [ + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "links": [ + 28 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyFluxControlNet" + }, + "widgets_values": [ + 0.7000000000000001 + ] + }, + { + "id": 39, + "type": "PreviewImage", + "pos": [ + 444, + -130 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": { + "collapsed": false + }, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 113 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 46, + "type": "SaveImage", + "pos": [ + 621, + 146 + ], + "size": { + "0": 315, + "1": 270 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 115 + } + ], + "properties": {}, + "widgets_values": [ + "canny_process" + ] + }, + { + "id": 21, + "type": "PreviewImage", + "pos": [ + 1088, + -373 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 14, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 79, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 48, + "type": "CannyEdgePreprocessor", + "pos": [ + 102, + -227 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 112 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 113, + 114, + 115 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CannyEdgePreprocessor" + }, + "widgets_values": [ + 100, + 200, + 1024 + ] + }, + { + "id": 13, + "type": "LoadFluxControlNet", + "pos": [ + 121, + -375 + ], + "size": { + "0": 316.83343505859375, + "1": 86.47058868408203 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "ControlNet", + "type": "FluxControlNet", + "links": [ + 44 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "LoadFluxControlNet" + }, + "widgets_values": [ + "flux-dev", + "flux-canny-controlnet-v3.safetensors" + ] + }, + { + "id": 4, + "type": "DualCLIPLoader", + "pos": [ + 104, + -79 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 2, + 27 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "DualCLIPLoader" + }, + "widgets_values": [ + "clip_l.safetensors", + "t5xxl_fp16.safetensors", + "flux" + ] + }, + { + "id": 19, + "type": "CLIPTextEncodeFlux", + "pos": [ + 203, + 167 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 27, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 26 + ], + 
"shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "bad photo", + "bad photo", + 4 + ] + }, + { + "id": 32, + "type": "UNETLoader", + "pos": [ + 502, + 452 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 58 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "flux1-dev-fp8.safetensors", + "fp8_e4m3fn" + ] + }, + { + "id": 6, + "type": "EmptyLatentImage", + "pos": [ + 850, + 449 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 66 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 1024, + 1024, + 1 + ] + }, + { + "id": 7, + "type": "VAEDecode", + "pos": [ + 1164, + 40 + ], + "size": { + "0": 210, + "1": 46 + }, + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 6, + "slot_index": 0 + }, + { + "name": "vae", + "type": "VAE", + "link": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 79, + 101 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 16, + "type": "LoadImage", + "pos": [ + -253, + -229 + ], + "size": { + "0": 315, + "1": 314 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 112 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "dining room of a modern brutalist house in the eng.webp", + "image" + ] + }, + { + "id": 23, + "type": "SaveImage", + "pos": [ + 1441, + -122 + ], + "size": { + "0": 356.1513671875, + "1": 270 + }, + "flags": {}, + "order": 15, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 101 + } + ], + "properties": {}, + "widgets_values": [ + "canny_res" + ] + }, + { + "id": 8, + "type": "VAELoader", + "pos": [ + 1111, + -84 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 7 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "ae.safetensors" + ] + }, + { + "id": 5, + "type": "CLIPTextEncodeFlux", + "pos": [ + 688, + -110 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 2, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 18 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "cyberpank dining room, full hd, cinematic", + "cyberpank dining room, full hd, cinematic", + 4 + ] + }, + { + "id": 3, + "type": "XlabsSampler", + "pos": [ + 948, + 149 + ], + "size": { + "0": 342.5999755859375, + "1": 258 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 58, + "slot_index": 0 + }, + { + "name": "conditioning", + "type": 
"CONDITIONING", + "link": 18 + }, + { + "name": "neg_conditioning", + "type": "CONDITIONING", + "link": 26 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 66 + }, + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "link": 28 + } + ], + "outputs": [ + { + "name": "latent", + "type": "LATENT", + "links": [ + 6 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "XlabsSampler" + }, + "widgets_values": [ + 324242532548, + "fixed", + 25, + 1, + 3.5, + 0 + ] + } + ], + "links": [ + [ + 2, + 4, + 0, + 5, + 0, + "CLIP" + ], + [ + 6, + 3, + 0, + 7, + 0, + "LATENT" + ], + [ + 7, + 8, + 0, + 7, + 1, + "VAE" + ], + [ + 18, + 5, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 26, + 19, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 27, + 4, + 0, + 19, + 0, + "CLIP" + ], + [ + 28, + 14, + 0, + 3, + 4, + "ControlNetCondition" + ], + [ + 44, + 13, + 0, + 14, + 0, + "FluxControlNet" + ], + [ + 58, + 32, + 0, + 3, + 0, + "MODEL" + ], + [ + 66, + 6, + 0, + 3, + 3, + "LATENT" + ], + [ + 79, + 7, + 0, + 21, + 0, + "IMAGE" + ], + [ + 101, + 7, + 0, + 23, + 0, + "IMAGE" + ], + [ + 112, + 16, + 0, + 48, + 0, + "IMAGE" + ], + [ + 113, + 48, + 0, + 39, + 0, + "IMAGE" + ], + [ + 114, + 48, + 0, + 14, + 1, + "IMAGE" + ], + [ + 115, + 48, + 0, + 46, + 0, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.9229599817706443, + "offset": [ + 305.8091888316129, + 456.5666981874018 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/x-flux-comfyui/workflows/flux-controlnet-depth-v3-workflow.json b/x-flux-comfyui/workflows/flux-controlnet-depth-v3-workflow.json new file mode 100644 index 0000000000000000000000000000000000000000..a37f5d37c1855386912adbfd66463e9d74e0edb6 --- /dev/null +++ b/x-flux-comfyui/workflows/flux-controlnet-depth-v3-workflow.json @@ -0,0 +1,738 @@ +{ + "last_node_id": 49, + "last_link_id": 122, + "nodes": [ + { + "id": 14, + "type": "ApplyFluxControlNet", + "pos": [ + 546, + -264 + ], + "size": { + "0": 393, + "1": 78 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "controlnet", + "type": "FluxControlNet", + "link": 44 + }, + { + "name": "image", + "type": "IMAGE", + "link": 121, + "slot_index": 1 + } + ], + "outputs": [ + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "links": [ + 28 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyFluxControlNet" + }, + "widgets_values": [ + 0.7000000000000001 + ] + }, + { + "id": 3, + "type": "XlabsSampler", + "pos": [ + 1119, + 136 + ], + "size": { + "0": 342.5999755859375, + "1": 258 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 58, + "slot_index": 0 + }, + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 18 + }, + { + "name": "neg_conditioning", + "type": "CONDITIONING", + "link": 26 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 66 + }, + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "link": 28 + } + ], + "outputs": [ + { + "name": "latent", + "type": "LATENT", + "links": [ + 6 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "XlabsSampler" + }, + "widgets_values": [ + 324242532548, + "fixed", + 25, + 1, + 3.5, + 0 + ] + }, + { + "id": 5, + "type": "CLIPTextEncodeFlux", + "pos": [ + 706, + -83 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": 
"CLIP", + "link": 2, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 18 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "handsome man in balenciaga style, fashion, vogue image", + "handsome man in balenciaga style, fashion, vogue image", + 4 + ] + }, + { + "id": 13, + "type": "LoadFluxControlNet", + "pos": [ + 102, + -376 + ], + "size": { + "0": 316.83343505859375, + "1": 86.47058868408203 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "ControlNet", + "type": "FluxControlNet", + "links": [ + 44 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "LoadFluxControlNet" + }, + "widgets_values": [ + "flux-dev", + "flux-depth-controlnet-v3.safetensors" + ] + }, + { + "id": 46, + "type": "SaveImage", + "pos": [ + 630, + 160 + ], + "size": { + "0": 315, + "1": 270 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 120 + } + ], + "properties": {}, + "widgets_values": [ + "canny_process" + ] + }, + { + "id": 19, + "type": "CLIPTextEncodeFlux", + "pos": [ + 210, + 180 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 27, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 26 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "bad photo", + "bad photo", + 4 + ] + }, + { + "id": 7, + "type": "VAEDecode", + "pos": [ + 1160, + 50 + ], + "size": { + "0": 210, + "1": 46 + }, + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 6, + "slot_index": 0 + }, + { + "name": "vae", + "type": "VAE", + "link": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 79, + 101 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 32, + "type": "UNETLoader", + "pos": [ + 490, + 480 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 58 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "flux1-dev-fp8.safetensors", + "fp8_e4m3fn" + ] + }, + { + "id": 6, + "type": "EmptyLatentImage", + "pos": [ + 910, + 380 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 66 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 1024, + 1024, + 1 + ] + }, + { + "id": 21, + "type": "PreviewImage", + "pos": [ + 1090, + -340 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 14, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 79, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 23, + "type": "SaveImage", + "pos": [ + 1420, + -190 + ], + "size": { + "0": 356.1513671875, + "1": 270 + }, + "flags": {}, + "order": 15, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 101 + } 
+ ], + "properties": {}, + "widgets_values": [ + "canny_res" + ] + }, + { + "id": 39, + "type": "PreviewImage", + "pos": [ + 470, + -130 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": { + "collapsed": false + }, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 122 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 4, + "type": "DualCLIPLoader", + "pos": [ + 130, + -20 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 2, + 27 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "DualCLIPLoader" + }, + "widgets_values": [ + "clip_l.safetensors", + "t5xxl_fp16.safetensors", + "flux" + ] + }, + { + "id": 49, + "type": "MiDaS-DepthMapPreprocessor", + "pos": [ + 120, + -190 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 118 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 120, + 121, + 122 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "MiDaS-DepthMapPreprocessor" + }, + "widgets_values": [ + 6.283185307179586, + 0.1, + 1024 + ] + }, + { + "id": 16, + "type": "LoadImage", + "pos": [ + -210, + -330 + ], + "size": { + "0": 315, + "1": 314 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 118 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "input_pose_2.png", + "image" + ] + }, + { + "id": 8, + "type": "VAELoader", + "pos": [ + 1100, + -50 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 7 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "ae.safetensors" + ] + } + ], + "links": [ + [ + 2, + 4, + 0, + 5, + 0, + "CLIP" + ], + [ + 6, + 3, + 0, + 7, + 0, + "LATENT" + ], + [ + 7, + 8, + 0, + 7, + 1, + "VAE" + ], + [ + 18, + 5, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 26, + 19, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 27, + 4, + 0, + 19, + 0, + "CLIP" + ], + [ + 28, + 14, + 0, + 3, + 4, + "ControlNetCondition" + ], + [ + 44, + 13, + 0, + 14, + 0, + "FluxControlNet" + ], + [ + 58, + 32, + 0, + 3, + 0, + "MODEL" + ], + [ + 66, + 6, + 0, + 3, + 3, + "LATENT" + ], + [ + 79, + 7, + 0, + 21, + 0, + "IMAGE" + ], + [ + 101, + 7, + 0, + 23, + 0, + "IMAGE" + ], + [ + 118, + 16, + 0, + 49, + 0, + "IMAGE" + ], + [ + 120, + 49, + 0, + 46, + 0, + "IMAGE" + ], + [ + 121, + 49, + 0, + 14, + 1, + "IMAGE" + ], + [ + 122, + 49, + 0, + 39, + 0, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.6934334949441344, + "offset": [ + 610.0322939307831, + 836.0918165037601 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/x-flux-comfyui/workflows/flux-controlnet-hed-v3-workflow.json b/x-flux-comfyui/workflows/flux-controlnet-hed-v3-workflow.json new file mode 100644 index 0000000000000000000000000000000000000000..7aee510e624de76e0bb68d8c7642ddb977e791fd --- /dev/null +++ b/x-flux-comfyui/workflows/flux-controlnet-hed-v3-workflow.json @@ -0,0 +1,737 @@ +{ + 
"last_node_id": 47, + "last_link_id": 111, + "nodes": [ + { + "id": 14, + "type": "ApplyFluxControlNet", + "pos": [ + 505, + -378 + ], + "size": { + "0": 393, + "1": 78 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "controlnet", + "type": "FluxControlNet", + "link": 44 + }, + { + "name": "image", + "type": "IMAGE", + "link": 110, + "slot_index": 1 + } + ], + "outputs": [ + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "links": [ + 28 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyFluxControlNet" + }, + "widgets_values": [ + 0.7000000000000001 + ] + }, + { + "id": 6, + "type": "EmptyLatentImage", + "pos": [ + 669, + 150 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 66 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 1024, + 1024, + 1 + ] + }, + { + "id": 46, + "type": "SaveImage", + "pos": [ + 969, + 278 + ], + "size": { + "0": 315, + "1": 270 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 111 + } + ], + "properties": {}, + "widgets_values": [ + "canny_process" + ] + }, + { + "id": 5, + "type": "CLIPTextEncodeFlux", + "pos": [ + 556, + -238 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 2, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 18 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "A beautiful woman with white hair and light freckles, her neck area bare and visible, capturing attention with her attitude, has modeled for an editorial magazine, captured in full body, fashion photography, within the scope of high future fashion, photographed by Alessio Albi.", + "A beautiful woman with white hair and light freckles, her neck area bare and visible, capturing attention with her attitude, has modeled for an editorial magazine, captured in full body, fashion photography, within the scope of high future fashion, photographed by Alessio Albi.", + 4 + ] + }, + { + "id": 39, + "type": "PreviewImage", + "pos": [ + 409, + -141 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": { + "collapsed": false + }, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 109 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 8, + "type": "VAELoader", + "pos": [ + 946, + -361 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 7 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "ae.safetensors" + ] + }, + { + "id": 7, + "type": "VAEDecode", + "pos": [ + 986, + -227 + ], + "size": { + "0": 210, + "1": 46 + }, + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 6, + "slot_index": 0 + }, + { + "name": "vae", + "type": "VAE", + "link": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 79, + 101 + ], + "shape": 3, + "slot_index": 0 + } + ], + 
"properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 21, + "type": "PreviewImage", + "pos": [ + 1322, + -365 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 14, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 79, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 19, + "type": "CLIPTextEncodeFlux", + "pos": [ + 182, + 146 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 27, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 26 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "bad photo", + "bad photo", + 4 + ] + }, + { + "id": 32, + "type": "UNETLoader", + "pos": [ + 634, + 305 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 58 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "flux1-dev-fp8.safetensors", + "fp8_e4m3fn" + ] + }, + { + "id": 4, + "type": "DualCLIPLoader", + "pos": [ + 0, + -36 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 2, + 27 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "DualCLIPLoader" + }, + "widgets_values": [ + "clip_l.safetensors", + "t5xxl_fp16.safetensors", + "flux" + ] + }, + { + "id": 47, + "type": "HEDPreprocessor", + "pos": [ + 92, + -170 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 108 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 109, + 110, + 111 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "HEDPreprocessor" + }, + "widgets_values": [ + "enable", + 1024 + ] + }, + { + "id": 13, + "type": "LoadFluxControlNet", + "pos": [ + 147, + -366 + ], + "size": { + "0": 316.83343505859375, + "1": 86.47058868408203 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "ControlNet", + "type": "FluxControlNet", + "links": [ + 44 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "LoadFluxControlNet" + }, + "widgets_values": [ + "flux-dev", + "flux-hed-controlnet-v3.safetensors" + ] + }, + { + "id": 16, + "type": "LoadImage", + "pos": [ + -236, + -397 + ], + "size": { + "0": 315, + "1": 314 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 108 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "A beautiful woman with white hair and light freckl.webp", + "image" + ] + }, + { + "id": 3, + "type": "XlabsSampler", + "pos": [ + 964, + -75 + ], + "size": { + "0": 342.5999755859375, + "1": 258 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 58, + "slot_index": 0 + }, + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 18 + }, + { + "name": 
"neg_conditioning", + "type": "CONDITIONING", + "link": 26 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 66 + }, + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "link": 28 + } + ], + "outputs": [ + { + "name": "latent", + "type": "LATENT", + "links": [ + 6 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "XlabsSampler" + }, + "widgets_values": [ + 324242532548, + "fixed", + 25, + 1, + 3.5, + 0 + ] + }, + { + "id": 23, + "type": "SaveImage", + "pos": [ + 1322, + -64 + ], + "size": { + "0": 356.1513671875, + "1": 270 + }, + "flags": {}, + "order": 15, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 101 + } + ], + "properties": {}, + "widgets_values": [ + "canny_res" + ] + } + ], + "links": [ + [ + 2, + 4, + 0, + 5, + 0, + "CLIP" + ], + [ + 6, + 3, + 0, + 7, + 0, + "LATENT" + ], + [ + 7, + 8, + 0, + 7, + 1, + "VAE" + ], + [ + 18, + 5, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 26, + 19, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 27, + 4, + 0, + 19, + 0, + "CLIP" + ], + [ + 28, + 14, + 0, + 3, + 4, + "ControlNetCondition" + ], + [ + 44, + 13, + 0, + 14, + 0, + "FluxControlNet" + ], + [ + 58, + 32, + 0, + 3, + 0, + "MODEL" + ], + [ + 66, + 6, + 0, + 3, + 3, + "LATENT" + ], + [ + 79, + 7, + 0, + 21, + 0, + "IMAGE" + ], + [ + 101, + 7, + 0, + 23, + 0, + "IMAGE" + ], + [ + 108, + 16, + 0, + 47, + 0, + "IMAGE" + ], + [ + 109, + 47, + 0, + 39, + 0, + "IMAGE" + ], + [ + 110, + 47, + 0, + 14, + 1, + "IMAGE" + ], + [ + 111, + 47, + 0, + 46, + 0, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.922959981770646, + "offset": [ + 263.90086107482557, + 618.1755731182905 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/x-flux-comfyui/workflows/hed_workflow.json b/x-flux-comfyui/workflows/hed_workflow.json new file mode 100644 index 0000000000000000000000000000000000000000..3728dbb7eea50f6cffd106306ac9fe7f990df4bb --- /dev/null +++ b/x-flux-comfyui/workflows/hed_workflow.json @@ -0,0 +1,665 @@ +{ + "last_node_id": 24, + "last_link_id": 44, + "nodes": [ + { + "id": 4, + "type": "DualCLIPLoader", + "pos": [ + -157, + 198 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 2, + 27 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "DualCLIPLoader" + }, + "widgets_values": [ + "clip_l.safetensors", + "t5xxl_fp16.safetensors", + "flux" + ] + }, + { + "id": 10, + "type": "UNETLoader", + "pos": [ + 243, + 590 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 35 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "flux1-dev-fp8.safetensors", + "fp8_e4m3fn" + ] + }, + { + "id": 19, + "type": "CLIPTextEncodeFlux", + "pos": [ + 65, + 281 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 27, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 26 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "", + "", + 4 + ] + }, + { + "id": 7, + "type": "VAEDecode", + "pos": [ + 1371, + 152 + ], + "size": { + "0": 
210, + "1": 46 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 6, + "slot_index": 0 + }, + { + "name": "vae", + "type": "VAE", + "link": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 31 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 8, + "type": "VAELoader", + "pos": [ + 1130, + 0 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 7 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "ae.safetensors" + ] + }, + { + "id": 13, + "type": "LoadFluxControlNet", + "pos": [ + 6, + -226 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "ControlNet", + "type": "FluxControlNet", + "links": [ + 19 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "LoadFluxControlNet" + }, + "widgets_values": [ + "flux-dev", + "flux-hed-controlnet.safetensors" + ] + }, + { + "id": 24, + "type": "HEDPreprocessor", + "pos": [ + -53, + -8 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 42 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 43, + 44 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "HEDPreprocessor" + }, + "widgets_values": [ + "enable", + 768 + ] + }, + { + "id": 17, + "type": "PreviewImage", + "pos": [ + 288, + -44 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 43, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 3, + "type": "XlabsSampler", + "pos": [ + 1007, + 170 + ], + "size": { + "0": 342.5999755859375, + "1": 234 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 35, + "slot_index": 0 + }, + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 18 + }, + { + "name": "neg_conditioning", + "type": "CONDITIONING", + "link": 26 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 5 + }, + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "link": 28 + } + ], + "outputs": [ + { + "name": "latent", + "type": "LATENT", + "links": [ + 6 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "XlabsSampler" + }, + "widgets_values": [ + 257762932021984, + "fixed", + 25, + 1, + 3.5 + ] + }, + { + "id": 5, + "type": "CLIPTextEncodeFlux", + "pos": [ + 518, + -63 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 2, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 18 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "man with microphone in the desert, disney", + "man with microphone in the desert, disney", + 4 + ] + }, + { + "id": 21, + "type": "PreviewImage", + "pos": [ + 1615, + 99 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 13, + "mode": 0, + 
"inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 31, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 14, + "type": "ApplyFluxControlNet", + "pos": [ + 546, + -262 + ], + "size": { + "0": 393, + "1": 78 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "controlnet", + "type": "FluxControlNet", + "link": 19 + }, + { + "name": "image", + "type": "IMAGE", + "link": 44, + "slot_index": 1 + } + ], + "outputs": [ + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "links": [ + 28 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyFluxControlNet" + }, + "widgets_values": [ + 0.7000000000000001 + ] + }, + { + "id": 6, + "type": "EmptyLatentImage", + "pos": [ + 769, + 430 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 5 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 1024, + 1024, + 1 + ] + }, + { + "id": 16, + "type": "LoadImage", + "pos": [ + -378, + -239 + ], + "size": { + "0": 315, + "1": 314 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 42 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "example.jpg", + "image" + ] + } + ], + "links": [ + [ + 2, + 4, + 0, + 5, + 0, + "CLIP" + ], + [ + 5, + 6, + 0, + 3, + 3, + "LATENT" + ], + [ + 6, + 3, + 0, + 7, + 0, + "LATENT" + ], + [ + 7, + 8, + 0, + 7, + 1, + "VAE" + ], + [ + 18, + 5, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 19, + 13, + 0, + 14, + 0, + "FluxControlNet" + ], + [ + 26, + 19, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 27, + 4, + 0, + 19, + 0, + "CLIP" + ], + [ + 28, + 14, + 0, + 3, + 4, + "ControlNetCondition" + ], + [ + 31, + 7, + 0, + 21, + 0, + "IMAGE" + ], + [ + 35, + 10, + 0, + 3, + 0, + "MODEL" + ], + [ + 42, + 16, + 0, + 24, + 0, + "IMAGE" + ], + [ + 43, + 24, + 0, + 17, + 0, + "IMAGE" + ], + [ + 44, + 24, + 0, + 14, + 1, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1.486436280241436, + "offset": [ + -904.5042545771234, + 205.34759825581054 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/x-flux-comfyui/workflows/ip_adapter_multiple_inputs_workflow.json b/x-flux-comfyui/workflows/ip_adapter_multiple_inputs_workflow.json new file mode 100644 index 0000000000000000000000000000000000000000..5bf212c6bb387c0d40839a931ca11553a1a3ccce --- /dev/null +++ b/x-flux-comfyui/workflows/ip_adapter_multiple_inputs_workflow.json @@ -0,0 +1,627 @@ +{ + "last_node_id": 49, + "last_link_id": 102, + "nodes": [ + { + "id": 3, + "type": "XlabsSampler", + "pos": [ + 866, + -13 + ], + "size": { + "0": 344.2750244140625, + "1": 591.3247680664062 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 99, + "slot_index": 0 + }, + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 18 + }, + { + "name": "neg_conditioning", + "type": "CONDITIONING", + "link": 101 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 75 + }, + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "link": null + } + ], + "outputs": [ + { + 
"name": "latent", + "type": "LATENT", + "links": [ + 6 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "XlabsSampler" + }, + "widgets_values": [ + 4, + "fixed", + 20, + 1, + 3.5, + 0, + 1 + ] + }, + { + "id": 19, + "type": "CLIPTextEncodeFlux", + "pos": [ + 274, + 217 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 27, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 101 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "", + "", + 4 + ] + }, + { + "id": 6, + "type": "EmptyLatentImage", + "pos": [ + 351, + 459 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 75 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 1024, + 1024, + 1 + ] + }, + { + "id": 5, + "type": "CLIPTextEncodeFlux", + "pos": [ + 279, + -11 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 2, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 18 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "holding sign with glowing green text \"X-LABS IP Adapter\"", + "holding sign with glowing green text \"X-LABS IP Adapter\"", + 4 + ] + }, + { + "id": 10, + "type": "UNETLoader", + "pos": [ + -243, + -45 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 100 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "flux1-dev-fp8.safetensors", + "fp8_e4m3fn" + ] + }, + { + "id": 48, + "type": "ApplyFluxIPAdapter", + "pos": [ + 391, + -252 + ], + "size": { + "0": 315, + "1": 98 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 100 + }, + { + "name": "ip_adapter_flux", + "type": "IP_ADAPTER_FLUX", + "link": 96 + }, + { + "name": "image", + "type": "IMAGE", + "link": 97 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 99 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyFluxIPAdapter" + }, + "widgets_values": [ + 0.6 + ] + }, + { + "id": 4, + "type": "DualCLIPLoader", + "pos": [ + -237, + 244 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 2, + 27 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "DualCLIPLoader" + }, + "widgets_values": [ + "flux/t5xxl_fp16.safetensors", + "flux/clip_l.safetensors", + "flux" + ] + }, + { + "id": 32, + "type": "LoadFluxIPAdapter", + "pos": [ + -236, + 81 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "ipadapterFlux", + "type": "IP_ADAPTER_FLUX", + "links": [ + 96 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "LoadFluxIPAdapter" + }, + 
"widgets_values": [ + "80000.safetensors", + "flux/clip_vision_l.safetensors", + "CPU" + ] + }, + { + "id": 49, + "type": "PreviewImage", + "pos": [ + 112, + -414 + ], + "size": { + "0": 210, + "1": 26 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 102, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 8, + "type": "VAELoader", + "pos": [ + 878, + -162 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 59 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "ae.safetensors" + ] + }, + { + "id": 36, + "type": "PreviewImage", + "pos": [ + 1399, + -152 + ], + "size": { + "0": 865.8053588867188, + "1": 863.5560913085938 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 76, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 7, + "type": "VAEDecode", + "pos": [ + 1418, + -235 + ], + "size": { + "0": 210, + "1": 46 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 6, + "slot_index": 0 + }, + { + "name": "vae", + "type": "VAE", + "link": 59 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 76 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 41, + "type": "CR Load Image List", + "pos": [ + -229, + -254 + ], + "size": { + "0": 315, + "1": 150 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 97, + 102 + ], + "shape": 6, + "slot_index": 0 + }, + { + "name": "show_help", + "type": "STRING", + "links": [], + "shape": 3, + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "CR Load Image List" + }, + "widgets_values": [ + "pasted", + 8, + 2, + "/ComfyUI/input/" + ] + } + ], + "links": [ + [ + 2, + 4, + 0, + 5, + 0, + "CLIP" + ], + [ + 6, + 3, + 0, + 7, + 0, + "LATENT" + ], + [ + 18, + 5, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 27, + 4, + 0, + 19, + 0, + "CLIP" + ], + [ + 59, + 8, + 0, + 7, + 1, + "VAE" + ], + [ + 75, + 6, + 0, + 3, + 3, + "LATENT" + ], + [ + 76, + 7, + 0, + 36, + 0, + "IMAGE" + ], + [ + 96, + 32, + 0, + 48, + 1, + "IP_ADAPTER_FLUX" + ], + [ + 97, + 41, + 0, + 48, + 2, + "IMAGE" + ], + [ + 99, + 48, + 0, + 3, + 0, + "MODEL" + ], + [ + 100, + 10, + 0, + 48, + 0, + "MODEL" + ], + [ + 101, + 19, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 102, + 41, + 0, + 49, + 0, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.5445000000000003, + "offset": [ + 263.58869460950524, + 511.3285537569527 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/x-flux-comfyui/workflows/ip_adapter_workflow.json b/x-flux-comfyui/workflows/ip_adapter_workflow.json new file mode 100644 index 0000000000000000000000000000000000000000..e250abd9808f8c21982ddce906f933429236e5cf --- /dev/null +++ b/x-flux-comfyui/workflows/ip_adapter_workflow.json @@ -0,0 +1,728 @@ +{ + "last_node_id": 36, + "last_link_id": 76, + "nodes": [ + { + "id": 6, + "type": "EmptyLatentImage", + "pos": [ + 553, + 475 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": 
"LATENT", + "type": "LATENT", + "links": [ + 75 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 1024, + 1024, + 1 + ] + }, + { + "id": 19, + "type": "CLIPTextEncodeFlux", + "pos": [ + 142, + 288 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 27, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 26 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "", + "", + 4 + ] + }, + { + "id": 8, + "type": "VAELoader", + "pos": [ + 1048, + 347 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 59 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "ae.safetensors" + ] + }, + { + "id": 35, + "type": "FluxLoraLoader", + "pos": [ + 1020, + -158 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": null + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "FluxLoraLoader" + }, + "widgets_values": [ + "anime_lora.safetensors", + 1 + ] + }, + { + "id": 10, + "type": "UNETLoader", + "pos": [ + 149, + 589 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 61 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "flux1-dev-fp8.safetensors", + "fp8_e4m3fn" + ] + }, + { + "id": 3, + "type": "XlabsSampler", + "pos": [ + 887, + 57 + ], + "size": { + "0": 342.5999755859375, + "1": 258 + }, + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 62, + "slot_index": 0 + }, + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 18 + }, + { + "name": "neg_conditioning", + "type": "CONDITIONING", + "link": 26 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 75 + }, + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "link": null + } + ], + "outputs": [ + { + "name": "latent", + "type": "LATENT", + "links": [ + 6 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "XlabsSampler" + }, + "widgets_values": [ + 4, + "fixed", + 50, + 1, + 3.5, + 0 + ] + }, + { + "id": 32, + "type": "LoadFluxIPAdapter", + "pos": [ + 313, + 147 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "ipadapterFlux", + "type": "IP_ADAPTER_FLUX", + "links": [ + 65 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "LoadFluxIPAdapter" + }, + "widgets_values": [ + "flux-ip-adapter.safetensors", + "model.safetensors", + "CPU" + ] + }, + { + "id": 5, + "type": "CLIPTextEncodeFlux", + "pos": [ + 428, + -169 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 2, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ 
+ 18 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "holding sign with glowing green text \"X-LABS IP Adapter\"", + "holding sign with glowing green text \"X-LABS IP Adapter\"", + 4 + ] + }, + { + "id": 27, + "type": "ApplyFluxIPAdapter", + "pos": [ + 642, + 248 + ], + "size": { + "0": 210, + "1": 98 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 61, + "slot_index": 0 + }, + { + "name": "ip_adapter_flux", + "type": "IP_ADAPTER_FLUX", + "link": 65 + }, + { + "name": "image", + "type": "IMAGE", + "link": 73, + "slot_index": 2 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 62 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ApplyFluxIPAdapter" + }, + "widgets_values": [ + 0.92 + ] + }, + { + "id": 29, + "type": "ImageCrop", + "pos": [ + -54, + 53 + ], + "size": { + "0": 315, + "1": 130 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 55, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ImageCrop" + }, + "widgets_values": [ + 1024, + 512, + 4, + 4 + ] + }, + { + "id": 33, + "type": "ImageScale", + "pos": [ + -80, + -148 + ], + "size": { + "0": 315, + "1": 130 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 72, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 73 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ImageScale" + }, + "widgets_values": [ + "nearest-exact", + 1024, + 1024, + "disabled" + ] + }, + { + "id": 16, + "type": "LoadImage", + "pos": [ + -446, + -191 + ], + "size": { + "0": 315, + "1": 314 + }, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 55, + 72 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "statue.jpg", + "image" + ] + }, + { + "id": 4, + "type": "DualCLIPLoader", + "pos": [ + -275, + 322 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 6, + "mode": 0, + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 2, + 27 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "DualCLIPLoader" + }, + "widgets_values": [ + "t5xxl_fp16.safetensors", + "clip_l.safetensors", + "flux" + ] + }, + { + "id": 36, + "type": "PreviewImage", + "pos": [ + 1663, + -228 + ], + "size": { + "0": 865.8053588867188, + "1": 863.5560913085938 + }, + "flags": {}, + "order": 14, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 76, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 7, + "type": "VAEDecode", + "pos": [ + 1346, + -128 + ], + "size": { + "0": 210, + "1": 46 + }, + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 6, + "slot_index": 0 + }, + { + "name": "vae", + "type": "VAE", + "link": 59 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 76 + ], + "shape": 3, + 
"slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + } + ], + "links": [ + [ + 2, + 4, + 0, + 5, + 0, + "CLIP" + ], + [ + 6, + 3, + 0, + 7, + 0, + "LATENT" + ], + [ + 18, + 5, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 26, + 19, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 27, + 4, + 0, + 19, + 0, + "CLIP" + ], + [ + 55, + 16, + 0, + 29, + 0, + "IMAGE" + ], + [ + 59, + 8, + 0, + 7, + 1, + "VAE" + ], + [ + 61, + 10, + 0, + 27, + 0, + "MODEL" + ], + [ + 62, + 27, + 0, + 3, + 0, + "MODEL" + ], + [ + 65, + 32, + 0, + 27, + 1, + "IP_ADAPTER_FLUX" + ], + [ + 72, + 16, + 0, + 33, + 0, + "IMAGE" + ], + [ + 73, + 33, + 0, + 27, + 2, + "IMAGE" + ], + [ + 75, + 6, + 0, + 3, + 3, + "LATENT" + ], + [ + 76, + 7, + 0, + 36, + 0, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.6727499949325612, + "offset": [ + 454.9047202912717, + 602.847204664566 + ] + } + }, + "version": 0.4 +} diff --git a/x-flux-comfyui/workflows/lora_workflow.json b/x-flux-comfyui/workflows/lora_workflow.json new file mode 100644 index 0000000000000000000000000000000000000000..23c7b2d813b4c9c3e02d5aa09eb1018994b6889a --- /dev/null +++ b/x-flux-comfyui/workflows/lora_workflow.json @@ -0,0 +1,489 @@ +{ + "last_node_id": 23, + "last_link_id": 37, + "nodes": [ + { + "id": 8, + "type": "VAELoader", + "pos": [ + 1102, + 48 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 7 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "ae.safetensors" + ] + }, + { + "id": 7, + "type": "VAEDecode", + "pos": [ + 1371, + 152 + ], + "size": { + "0": 210, + "1": 46 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 6, + "slot_index": 0 + }, + { + "name": "vae", + "type": "VAE", + "link": 7 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 31 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 19, + "type": "CLIPTextEncodeFlux", + "pos": [ + -17, + 116 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 27, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 26 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "", + "", + 4 + ] + }, + { + "id": 6, + "type": "EmptyLatentImage", + "pos": [ + 626, + 428 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 5 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 1024, + 1024, + 1 + ] + }, + { + "id": 10, + "type": "UNETLoader", + "pos": [ + 209, + 387 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 36 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "flux1-dev-fp8.safetensors", + "fp8_e4m3fn" + ] + }, + { + "id": 3, + "type": "XlabsSampler", + "pos": [ + 1013, + 169 + ], + "size": { + "0": 
342.5999755859375, + "1": 234 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 37, + "slot_index": 0 + }, + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 18 + }, + { + "name": "neg_conditioning", + "type": "CONDITIONING", + "link": 26 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 5 + }, + { + "name": "controlnet_condition", + "type": "ControlNetCondition", + "link": null + } + ], + "outputs": [ + { + "name": "latent", + "type": "LATENT", + "links": [ + 6 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "XlabsSampler" + }, + "widgets_values": [ + 257762932021983, + "fixed", + 25, + 1, + 3.5 + ] + }, + { + "id": 4, + "type": "DualCLIPLoader", + "pos": [ + -176, + -93 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 2, + 27 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "DualCLIPLoader" + }, + "widgets_values": [ + "clip_l.safetensors", + "t5xxl_fp16.safetensors", + "flux" + ] + }, + { + "id": 5, + "type": "CLIPTextEncodeFlux", + "pos": [ + 518, + -63 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 2, + "slot_index": 0 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 18 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "CLIPTextEncodeFlux" + }, + "widgets_values": [ + "furry in the city with text \"hello world\"", + "furry in the city with text \"hello world\"", + 3.5 + ] + }, + { + "id": 23, + "type": "FluxLoraLoader", + "pos": [ + 506, + 231 + ], + "size": { + "0": 315, + "1": 82 + }, + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 36 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 37 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "FluxLoraLoader" + }, + "widgets_values": [ + "furry_lora.safetensors", + 0.9 + ] + }, + { + "id": 21, + "type": "PreviewImage", + "pos": [ + 1612, + 128 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 31, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + } + ], + "links": [ + [ + 2, + 4, + 0, + 5, + 0, + "CLIP" + ], + [ + 5, + 6, + 0, + 3, + 3, + "LATENT" + ], + [ + 6, + 3, + 0, + 7, + 0, + "LATENT" + ], + [ + 7, + 8, + 0, + 7, + 1, + "VAE" + ], + [ + 18, + 5, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 26, + 19, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 27, + 4, + 0, + 19, + 0, + "CLIP" + ], + [ + 31, + 7, + 0, + 21, + 0, + "IMAGE" + ], + [ + 36, + 10, + 0, + 23, + 0, + "MODEL" + ], + [ + 37, + 23, + 0, + 3, + 0, + "MODEL" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 2.1762913579014866, + "offset": [ + -1101.4302395366494, + -13.803910340891065 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/x-flux-comfyui/xflux/LICENSE b/x-flux-comfyui/xflux/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/x-flux-comfyui/xflux/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + 
http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/x-flux-comfyui/xflux/README.md b/x-flux-comfyui/xflux/README.md new file mode 100644 index 0000000000000000000000000000000000000000..42a60d7dcacd1072fd04498d7459b150b4df8c27 --- /dev/null +++ b/x-flux-comfyui/xflux/README.md @@ -0,0 +1,234 @@ +![FLUX Finetuning scripts](./assets/readme/dark/header-rev1.png) + +This repository provides training scripts for [Flux model](https://github.com/black-forest-labs/flux) by Black Forest Labs.
+The [XLabs AI](https://github.com/XLabs-AI) team is happy to publish fine-tuning scripts for Flux, including: + +- **LoRA** 🔥 +- **ControlNet** 🔥 + +# Training + +We trained LoRA and ControlNet models using [DeepSpeed](https://github.com/microsoft/DeepSpeed)!
+Training at 1024x1024 resolution is supported! + +## Models + +We trained **Canny ControlNet**, **Depth ControlNet**, **HED ControlNet**, and **LoRA** checkpoints for [`FLUX.1 [dev]`](https://github.com/black-forest-labs/flux).
+You can download them on HuggingFace: + +- [flux-controlnet-collections](https://huggingface.co/XLabs-AI/flux-controlnet-collections) +- [flux-controlnet-canny](https://huggingface.co/XLabs-AI/flux-controlnet-canny) +- [flux-RealismLora](https://huggingface.co/XLabs-AI/flux-RealismLora) +- [flux-lora-collections](https://huggingface.co/XLabs-AI/flux-lora-collection) +- [flux-furry-lora](https://huggingface.co/XLabs-AI/flux-furry-lora) + +### LoRA + +```bash +accelerate launch train_flux_lora_deepspeed.py --config "train_configs/test_lora.yaml" +``` + +### ControlNet + +```bash +accelerate launch train_flux_deepspeed_controlnet.py --config "train_configs/test_canny_controlnet.yaml" +``` + +## Training Dataset + +Dataset has the following format for the training process: + +```text +├── images/ +│ ├── 1.png +│ ├── 1.json +│ ├── 2.png +│ ├── 2.json +│ ├── ... +``` + +### Example `images/*.json` file + +A `.json` file contains "caption" field with a text prompt. + +```json +{ + "caption": "A figure stands in a misty landscape, wearing a mask with antlers and dark, embellished attire, exuding mystery and otherworldlines" +} +``` + +## Inference + +To test our checkpoints, use commands presented below. + +### LoRA +![Example Picture 1](./assets/readme/examples/picture-5-rev1.png) +prompt: "A girl in a suit covered with bold tattoos and holding a vest pistol, beautiful woman, 25 years old, cool, future fantasy, turquoise & light orange ping curl hair" +![Example Picture 2](./assets/readme/examples/picture-6-rev1.png) +prompt: "A handsome man in a suit, 25 years old, cool, futuristic" + +```bash +python3 main.py \ + --prompt "Female furry Pixie with text 'hello world'" \ + --lora_repo_id XLabs-AI/flux-furry-lora --lora_name furry_lora.safetensors --device cuda --offload --use_lora \ + --model_type flux-dev-fp8 --width 1024 --height 1024 \ + --timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4 + +``` + +![Example Picture 1](./assets/readme/examples/furry4.png) + +```bash +python3 main.py \ +--prompt "A cute corgi lives in a house made out of sushi, anime" \ +--lora_repo_id XLabs-AI/flux-lora-collection --lora_name anime_lora.safetensors \ +--device cuda --offload --use_lora --model_type flux-dev-fp8 --width 1024 --height 1024 + +``` +![Example Picture 3](./assets/readme/examples/result_14.png) + + +```bash +python3 main.py \ + --use_lora --lora_weight 0.7 \ + --width 1024 --height 768 \ + --lora_repo_id XLabs-AI/flux-lora-collection --lora_name realism_lora.safetensors \ + --guidance 4 \ + --prompt "contrast play photography of a black female wearing white suit and albino asian geisha female wearing black suit, solid background, avant garde, high fashion" +``` +![Example Picture 3](./assets/readme/examples/picture-7-rev1.png) + +## Canny ControlNet +```bash +python3 main.py \ + --prompt "a viking man with white hair looking, cinematic, MM full HD" \ + --image input_image_canny.jpg \ + --control_type canny \ + --repo_id XLabs-AI/flux-controlnet-collections --name flux-canny-controlnet.safetensors --device cuda --use_controlnet \ + --model_type flux-dev --width 768 --height 768 \ + --timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4 + +``` +![Example Picture 1](./assets/readme/examples/canny_example_1.png?raw=true) + +## Depth ControlNet +```bash +python3 main.py \ + --prompt "Photo of the bold man with beard and laptop, full hd, cinematic photo" \ + --image input_image_depth1.jpg \ + --control_type depth \ + --repo_id XLabs-AI/flux-controlnet-collections --name 
flux-depth-controlnet.safetensors --device cuda --use_controlnet \ + --model_type flux-dev --width 1024 --height 1024 \ + --timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4 + +``` +![Example Picture 2](./assets/readme/examples/depth_example_1.png?raw=true) + +```bash +python3 main.py \ + --prompt "photo of handsome fluffy black dog standing on a forest path, full hd, cinematic photo" \ + --image input_image_depth2.jpg \ + --control_type depth \ + --repo_id XLabs-AI/flux-controlnet-collections --name flux-depth-controlnet.safetensors --device cuda --use_controlnet \ + --model_type flux-dev --width 1024 --height 1024 \ + --timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4 + +``` +![Example Picture 2](./assets/readme/examples/depth_example_2.png?raw=true) + +```bash +python3 main.py \ + --prompt "Photo of japanese village with houses and sakura, full hd, cinematic photo" \ + --image input_image_depth3.webp \ + --control_type depth \ + --repo_id XLabs-AI/flux-controlnet-collections --name flux-depth-controlnet.safetensors --device cuda --use_controlnet \ + --model_type flux-dev --width 1024 --height 1024 \ + --timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4 + +``` +![Example Picture 2](./assets/readme/examples/depth_example_3.png?raw=true) + + +## HED ControlNet +```bash +python3 main.py \ + --prompt "2d art of a sitting african rich woman, full hd, cinematic photo" \ + --image input_image_hed1.jpg \ + --control_type hed \ + --repo_id XLabs-AI/flux-controlnet-collections --name flux-hed-controlnet.safetensors --device cuda --use_controlnet \ + --model_type flux-dev --width 768 --height 768 \ + --timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4 + +``` +![Example Picture 2](./assets/readme/examples/hed_example_1.png?raw=true) + +```bash +python3 main.py \ + --prompt "anime ghibli style art of a running happy white dog, full hd" \ + --image input_image_hed2.jpg \ + --control_type hed \ + --repo_id XLabs-AI/flux-controlnet-collections --name flux-hed-controlnet.safetensors --device cuda --use_controlnet \ + --model_type flux-dev --width 768 --height 768 \ + --timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4 + +``` +![Example Picture 2](./assets/readme/examples/hed_example_2.png?raw=true) + +## Low memory mode + +Use LoRA and Controlnet FP8 version based on [Flux-dev-F8](https://huggingface.co/XLabs-AI/flux-dev-fp8) with `--offload` setting to achieve lower VRAM usage (22 GB) and `--name flux-dev-fp8`: +```bash +python3 main.py \ + --offload --name flux-dev-fp8 \ + --lora_repo_id XLabs-AI/flux-lora-collection --lora_name realism_lora.safetensors \ + --guidance 4 \ + --prompt "A handsome girl in a suit covered with bold tattoos and holding a pistol. 
Animatrix illustration style, fantasy style, natural photo cinematic" +``` +![Example Picture 0](./assets/readme/examples/picture-0-rev1.png) + +## Requirements + +Install our dependencies by running the following command: + +```bash +pip3 install -r requirements.txt +``` + +## Accelerate Configuration Example + +```yaml +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 2 + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false + +``` +## Models Licence + +Our models fall under the [FLUX.1 [dev] Non-Commercial License](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev)
Our training and inference scripts are released under the Apache 2.0 License + +## Near Updates + +We are working on releasing new ControlNet weight models for Flux: **OpenPose**, **Depth** and more!
+Stay tuned with [XLabs AI](https://github.com/XLabs-AI) to see **IP-Adapters** for Flux. + +![Follow Our Updates](./assets/readme/dark/follow-cta-rev2.png) diff --git a/x-flux-comfyui/xflux/src/flux/__init__.py b/x-flux-comfyui/xflux/src/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43c365a49d6980e88acba10ef3069f110a59644a --- /dev/null +++ b/x-flux-comfyui/xflux/src/flux/__init__.py @@ -0,0 +1,11 @@ +try: + from ._version import version as __version__ # type: ignore + from ._version import version_tuple +except ImportError: + __version__ = "unknown (no version information available)" + version_tuple = (0, 0, "unknown", "noinfo") + +from pathlib import Path + +PACKAGE = __package__.replace("_", "-") +PACKAGE_ROOT = Path(__file__).parent diff --git a/x-flux-comfyui/xflux/src/flux/annotator/util.py b/x-flux-comfyui/xflux/src/flux/annotator/util.py new file mode 100644 index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05 --- /dev/null +++ b/x-flux-comfyui/xflux/src/flux/annotator/util.py @@ -0,0 +1,38 @@ +import numpy as np +import cv2 +import os + + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img diff --git a/x-flux-comfyui/xflux/src/flux/controlnet.py b/x-flux-comfyui/xflux/src/flux/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a04cc0234b2b726a550cbe62d027943f6bbcbb --- /dev/null +++ b/x-flux-comfyui/xflux/src/flux/controlnet.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. 
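+    This ControlNet variant keeps only the first ``controlnet_depth`` double-stream blocks, embeds the conditioning image with ``input_hint_block``, and returns per-block residuals from the zero-initialized ``controlnet_blocks`` projections for the base Flux model to consume.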
+ """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams, controlnet_depth=2): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(controlnet_depth) + ] + ) + + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
+ + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_res_samples = () + + for block in self.double_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + ) + else: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + block_res_samples = block_res_samples + (img,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + return controlnet_block_res_samples diff --git a/x-flux-comfyui/xflux/src/flux/math.py b/x-flux-comfyui/xflux/src/flux/math.py new file mode 100644 index 0000000000000000000000000000000000000000..0156bb6a205dec340e029f0c87cf70ae8709ae12 --- /dev/null +++ b/x-flux-comfyui/xflux/src/flux/math.py @@ -0,0 +1,30 @@ +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, 
device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/x-flux-comfyui/xflux/src/flux/model.py b/x-flux-comfyui/xflux/src/flux/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c40194201a89f90784f5622575e9a2e6a5a42305 --- /dev/null +++ b/x-flux-comfyui/xflux/src/flux/model.py @@ -0,0 +1,217 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) +from typing import Dict, List, Any + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if 
hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + block_controlnet_hidden_states=None, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + for index_block, block in enumerate(self.double_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + ) + else: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + # controlnet residual + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[index_block % 2] + + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def 
custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if torch.is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + vec, + pe, + ) + else: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/x-flux-comfyui/xflux/src/flux/modules/autoencoder.py b/x-flux-comfyui/xflux/src/flux/modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..75159f711f65f064107a1a1b9be6f09fc9872028 --- /dev/null +++ b/x-flux-comfyui/xflux/src/flux/modules/autoencoder.py @@ -0,0 +1,312 @@ +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def 
forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + 
# upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/x-flux-comfyui/xflux/src/flux/modules/conditioner.py b/x-flux-comfyui/xflux/src/flux/modules/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..7cdd881878ace848745da7d723c60f03392916ab --- /dev/null +++ b/x-flux-comfyui/xflux/src/flux/modules/conditioner.py @@ -0,0 +1,38 @@ +from torch import Tensor, nn +from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, + T5Tokenizer) + + +class HFEmbedder(nn.Module): + def __init__(self, version: str, max_length: int, **hf_kwargs): + super().__init__() + self.is_clip = version.startswith("openai") + self.max_length = max_length + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + + if self.is_clip: + self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, 
**hf_kwargs) + else: + self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) + + self.hf_module = self.hf_module.eval().requires_grad_(False) + + def forward(self, text: list[str]) -> Tensor: + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + outputs = self.hf_module( + input_ids=batch_encoding["input_ids"].to(self.hf_module.device), + attention_mask=None, + output_hidden_states=False, + ) + return outputs[self.output_key] diff --git a/x-flux-comfyui/xflux/src/flux/modules/layers.py b/x-flux-comfyui/xflux/src/flux/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..8aadac04336cfde5e23460eebf1ff7d2a4626b98 --- /dev/null +++ b/x-flux-comfyui/xflux/src/flux/modules/layers.py @@ -0,0 +1,358 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from ..math import attention, rope + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
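+    :param time_factor: multiplied into ``t`` before embedding; the sampler passes timesteps in [0, 1], so the default of 1000 rescales them to the usual range.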
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + t.device + ) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + super().__init__() + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + +class FLuxSelfAttnProcessor: + def __call__(self, attn, x, pe, **attention_kwargs): + print('2' * 30) + + qkv = attn.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + return x + +class LoraFluxAttnProcessor(nn.Module): + + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + + def __call__(self, attn, x, pe, **attention_kwargs): + qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + self.proj_lora(x) * self.lora_weight + print('1' * 30) + print(x.norm(), (self.proj_lora(x) * self.lora_weight).norm(), 'norm') + return x + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + def forward(): + pass + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + +class DoubleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for 
attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + +class DoubleStreamBlockProcessor(nn.Module): + def __init__(self): + super().__init__() + def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): + + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 
= nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + processor = DoubleStreamBlockProcessor() + self.set_processor(processor) + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + return self.processor(self, img, txt, vec, pe) + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/x-flux-comfyui/xflux/src/flux/sampling.py b/x-flux-comfyui/xflux/src/flux/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..33c49b2e8367db1b5996d5ca9ef6090d67ddba1e --- /dev/null +++ b/x-flux-comfyui/xflux/src/flux/sampling.py @@ -0,0 +1,248 @@ +import math +from typing import Callable + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from .model import Flux +from .modules.conditioner import HFEmbedder + + +def get_noise( + num_samples: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + 
seed: int, +): + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + +def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: + bs, c, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... -> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + timestep_to_start_cfg=0, + image2image_strength=None, + orig_image = None, +): + i = 0 + + #init_latents = rearrange(init_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if image2image_strength is not None and orig_image is not None: + t_idx = int((1 - image2image_strength) * len(timesteps)) + t = timesteps[t_idx] + timesteps = timesteps[t_idx:] + img = t * img + (1.0 - t) * orig_image.to(img.dtype) + # this is ignored for schnell + if hasattr(model, "guidance_in"): + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + else: + guidance_vec = None + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + if i >= timestep_to_start_cfg: + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + 
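+        # Euler step of the flow-matching ODE: move from t_curr toward t_prev along the predicted velocity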
img = img + (t_prev - t_curr) * pred + i += 1 + return img + +def denoise_controlnet( + model: Flux, + controlnet:None, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + controlnet_cond, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + controlnet_gs=0.7, + timestep_to_start_cfg=0, + image2image_strength=None, + orig_image = None, +): + i = 0 + + #init_latents = rearrange(init_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if image2image_strength is not None and orig_image is not None: + t_idx = int((1 - image2image_strength) * len(timesteps)) + t = timesteps[t_idx] + timesteps = timesteps[t_idx:] + img = t * img + (1.0 - t) * orig_image.to(img.dtype) + # this is ignored for schnell + if hasattr(model, "guidance_in"): + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + else: + guidance_vec = None + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples] + ) + if i >= timestep_to_start_cfg: + neg_block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=[i * controlnet_gs for i in neg_block_res_samples] + ) + pred = neg_pred + true_gs * (pred - neg_pred) + + img = img + (t_prev - t_curr) * pred + + i += 1 + return img + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/x-flux-comfyui/xflux/src/flux/util.py b/x-flux-comfyui/xflux/src/flux/util.py new file mode 100644 index 0000000000000000000000000000000000000000..3aa0769b1a5246d2648150837b44a0699b7710cf --- /dev/null +++ b/x-flux-comfyui/xflux/src/flux/util.py @@ -0,0 +1,350 @@ +import os +from dataclasses import dataclass + +import torch +import json +import cv2 +import numpy as np +from PIL import Image +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from safetensors.torch import load_file as load_sft + +from .model import Flux, FluxParams +from .controlnet import ControlNetFlux +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder + + + +def load_safetensors(path): + tensors = {} + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + return tensors + +def get_lora_rank(checkpoint): + for k in checkpoint.keys(): + if k.endswith(".down.weight"): + return checkpoint[k].shape[0] + +def load_checkpoint(local_path, repo_id, name): + if local_path is not None: + if '.safetensors' in local_path: + print("Loading .safetensors checkpoint...") + 
checkpoint = load_safetensors(local_path) + else: + print("Loading checkpoint...") + checkpoint = torch.load(local_path, map_location='cpu') + elif repo_id is not None and name is not None: + print("Loading checkpoint from repo id...") + checkpoint = load_from_repo_id(repo_id, name) + else: + raise ValueError( + "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" + ) + return checkpoint + + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + + +class Annotator: + def __init__(self, name: str, device: str): + if name == "canny": + processor = CannyDetector() + elif name == "openpose": + processor = DWposeDetector(device) + elif name == "depth": + processor = MidasDetector() + elif name == "hed": + processor = HEDdetector() + elif name == "hough": + processor = MLSDdetector() + elif name == "tile": + processor = TileDetector() + self.name = name + self.processor = processor + + def __call__(self, image: Image, width: int, height: int): + image = c_crop(image) + image = image.resize((width, height)) + image = np.array(image) + if self.name == "canny": + result = self.processor(image, low_threshold=100, high_threshold=200) + elif self.name == "hough": + result = self.processor(image, thr_v=0.05, thr_d=5) + elif self.name == "depth": + result = self.processor(image) + result, _ = result + else: + result = self.processor(image) + + if result.ndim != 3: + result = result[:, :, None] + result = np.concatenate([result, result, result], axis=2) + return result + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + repo_id: str | None + repo_flow: str | None + repo_ae: str | None + repo_id_ae: str | None + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-fp8": ModelSpec( + repo_id="XLabs-AI/flux-dev-fp8", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux-dev-fp8.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV_FP8"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", 
+ ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + +def load_from_repo_id(repo_id, checkpoint_name): + ckpt_path = hf_hub_download(repo_id, checkpoint_name) + sd = load_sft(ckpt_path, device='cpu') + return sd + +def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params).to(torch.bfloat16) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_controlnet(name, device, transformer=None): + with torch.device(device): + controlnet = ControlNetFlux(configs[name].params) + if transformer is not None: + controlnet.load_state_dict(transformer.state_dict(), strict=False) + return controlnet + +def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device) + +def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: + return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) + + +def load_ae(name: str, 
device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
+ ckpt_path = configs[name].ae_path
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_ae is not None
+ and hf_download
+ ):
+ ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
+
+ # Loading the autoencoder
+ print("Init AE")
+ with torch.device("meta" if ckpt_path is not None else device):
+ ae = AutoEncoder(configs[name].ae_params)
+
+ if ckpt_path is not None:
+ sd = load_sft(ckpt_path, device=str(device))
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
+ print_load_warning(missing, unexpected)
+ return ae
+
+
+class WatermarkEmbedder:
+ def __init__(self, watermark):
+ self.watermark = watermark
+ self.num_bits = len(WATERMARK_BITS)
+ self.encoder = WatermarkEncoder()
+ self.encoder.set_watermark("bits", self.watermark)
+
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
+ """
+ Adds a predefined watermark to the input image
+
+ Args:
+ image: ([N,] B, RGB, H, W) in range [-1, 1]
+
+ Returns:
+ same as input but watermarked
+ """
+ image = 0.5 * image + 0.5
+ squeeze = len(image.shape) == 4
+ if squeeze:
+ image = image[None, ...]
+ n = image.shape[0]
+ image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) in [0, 255]
+ # the watermarking library expects input in cv2 BGR format
+ for k in range(image_np.shape[0]):
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
+ image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
+ image.device
+ )
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
+ if squeeze:
+ image = image[0]
+ image = 2 * image - 1
+ return image
+
+
+# A fixed 48-bit message that was chosen at random
+WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
+# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
diff --git a/x-flux-comfyui/xflux/src/flux/xflux_pipeline.py b/x-flux-comfyui/xflux/src/flux/xflux_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..7964bb8e9628ce9e695a46601f0960376a95a432
--- /dev/null
+++ b/x-flux-comfyui/xflux/src/flux/xflux_pipeline.py
@@ -0,0 +1,152 @@
+from PIL import Image
+import numpy as np
+import torch
+
+from einops import rearrange
+
+from src.flux.modules.layers import DoubleStreamBlockLoraProcessor
+from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack
+from src.flux.util import (load_ae, load_clip, load_flow_model, load_t5, load_controlnet,
+ load_flow_model_quintized, Annotator, get_lora_rank, load_checkpoint)
+
+
+class XFluxPipeline:
+ def __init__(self, model_type, device, offload: bool = False, seed: int | None = None):
+ self.device = torch.device(device)
+ self.offload = offload
+ self.seed = seed
+ self.model_type = model_type
+
+ self.clip = load_clip(self.device)
+ self.t5 = load_t5(self.device, max_length=512)
+ self.ae = load_ae(model_type, device="cpu" if offload else self.device)
+ if "fp8" in model_type:
+ self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
+ else:
+ self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
+
+ self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
+ self.lora_types_to_names = {
+ "realism": "lora.safetensors",
+ }
+ 
self.controlnet_loaded = False
+
+ def set_lora(self, local_path: str = None, repo_id: str = None,
+ name: str = None, lora_weight: float = 0.7):
+ checkpoint = load_checkpoint(local_path, repo_id, name)
+ self.update_model_with_lora(checkpoint, lora_weight)
+
+ def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
+ checkpoint = load_checkpoint(
+ None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
+ )
+ self.update_model_with_lora(checkpoint, lora_weight)
+
+ def update_model_with_lora(self, checkpoint, lora_weight):
+ rank = get_lora_rank(checkpoint)
+ lora_attn_procs = {}
+
+ for name, _ in self.model.attn_processors.items():
+ lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
+ lora_state_dict = {}
+ for k in checkpoint.keys():
+ if name in k:
+ lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
+ lora_attn_procs[name].load_state_dict(lora_state_dict)
+ lora_attn_procs[name].to(self.device)
+
+ self.model.set_attn_processor(lora_attn_procs)
+
+ def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
+ self.model.to(self.device)
+ self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)
+
+ checkpoint = load_checkpoint(local_path, repo_id, name)
+ self.controlnet.load_state_dict(checkpoint, strict=False)
+
+ if control_type == "depth":
+ self.controlnet_gs = 0.9
+ else:
+ self.controlnet_gs = 0.7
+ self.annotator = Annotator(control_type, self.device)
+ self.controlnet_loaded = True
+
+ def __call__(self,
+ prompt: str,
+ controlnet_image: Image.Image = None,
+ width: int = 512,
+ height: int = 512,
+ guidance: float = 4,
+ num_steps: int = 50,
+ true_gs = 3,
+ neg_prompt: str = '',
+ timestep_to_start_cfg: int = 0,
+ ):
+ width = 16 * (width // 16)  # round down to a multiple of 16
+ height = 16 * (height // 16)
+ if self.controlnet_loaded:
+ controlnet_image = self.annotator(controlnet_image, width, height)
+ controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
+ controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)
+
+ return self.forward(prompt, width, height, guidance, num_steps, controlnet_image,
+ timestep_to_start_cfg=timestep_to_start_cfg, true_gs=true_gs, neg_prompt=neg_prompt)
+
+ def forward(self, prompt, width, height, guidance, num_steps, controlnet_image=None, timestep_to_start_cfg=0, true_gs=3, neg_prompt=""):
+ x = get_noise(
+ 1, height, width, device=self.device,
+ dtype=torch.bfloat16, seed=self.seed
+ )
+ timesteps = get_schedule(
+ num_steps,
+ (width // 8) * (height // 8) // (16 * 16),
+ shift=True,
+ )
+ torch.manual_seed(self.seed)
+ with torch.no_grad():
+ if self.offload:
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
+ inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
+ neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)
+
+ if self.offload:
+ self.offload_model_to_cpu(self.t5, self.clip)
+ self.model = self.model.to(self.device)
+ if self.controlnet_loaded:
+ x = denoise_controlnet(
+ self.model, **inp_cond, controlnet=self.controlnet,
+ timesteps=timesteps, guidance=guidance,
+ controlnet_cond=controlnet_image,
+ timestep_to_start_cfg=timestep_to_start_cfg,
+ neg_txt=neg_inp_cond['txt'],
+ neg_txt_ids=neg_inp_cond['txt_ids'],
+ neg_vec=neg_inp_cond['vec'],
+ true_gs=true_gs,
+ controlnet_gs=self.controlnet_gs,
+ )
+ else:
+ x = denoise(self.model, **inp_cond, timesteps=timesteps, 
guidance=guidance, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs + ) + + if self.offload: + self.offload_model_to_cpu(self.model) + self.ae.decoder.to(x.device) + x = unpack(x.float(), height, width) + x = self.ae.decode(x) + self.offload_model_to_cpu(self.ae.decoder) + + x1 = x.clamp(-1, 1) + x1 = rearrange(x1[-1], "c h w -> h w c") + output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) + return output_img + + def offload_model_to_cpu(self, *models): + if not self.offload: return + for model in models: + model.cpu() + torch.cuda.empty_cache()
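For reference, a minimal usage sketch of the XFluxPipeline defined above. This is a sketch under stated assumptions, not part of the added files: it assumes the repository's xflux directory is on sys.path so the internal `src.flux.*` imports resolve, that a "flux-dev" entry exists in `configs` (only the schnell entry is visible in the util.py hunk above; "flux-dev" follows upstream x-flux), and that the prompt and output path are purely illustrative.

import sys
sys.path.insert(0, "x-flux-comfyui/xflux")  # assumption: repo checked out at this relative path

from src.flux.xflux_pipeline import XFluxPipeline

# seed must be a real int: forward() calls torch.manual_seed(self.seed)
# and get_noise(..., seed=self.seed), so leaving it as None would fail.
pipe = XFluxPipeline(model_type="flux-dev", device="cuda", offload=False, seed=42)

# optionally apply the "realism" LoRA from XLabs-AI/flux-lora-collection
pipe.set_lora_from_collection(lora_type="realism", lora_weight=0.7)

image = pipe(
    prompt="a lighthouse on a rocky coast at sunset",  # illustrative prompt
    width=1024,
    height=1024,
    guidance=4.0,
    num_steps=25,
    neg_prompt="low quality, blurry",
)
image.save("xflux_example.png")  # __call__ returns a PIL.Image.Image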