import os
import sys
import random
import torch
import numpy as np
from PIL import Image
import gradio as gr

# Add the repository root to sys.path so that both `nodes` and
# `totoro_extras` can be imported (the root is the directory that
# contains nodes.py and the totoro_extras/ package).
repo_path = './ComfyUI'
print(f"Checking for repository path: {repo_path}")
if not os.path.exists(repo_path):
    raise FileNotFoundError(f"Repository path '{repo_path}' not found. Make sure the ComfyUI repository is cloned correctly.")
sys.path.append(repo_path)
print(f"Repository path added to sys.path: {repo_path}")

# Import nodes and custom modules
from nodes import NODE_CLASS_MAPPINGS
from totoro_extras import nodes_custom_sampler, nodes_flux

# Initialize necessary components from the nodes
CheckpointLoaderSimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
LoraLoader = NODE_CLASS_MAPPINGS["LoraLoader"]()
FluxGuidance = nodes_flux.NODE_CLASS_MAPPINGS["FluxGuidance"]()
RandomNoise = nodes_custom_sampler.NODE_CLASS_MAPPINGS["RandomNoise"]()
BasicGuider = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicGuider"]()
KSamplerSelect = nodes_custom_sampler.NODE_CLASS_MAPPINGS["KSamplerSelect"]()
BasicScheduler = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicScheduler"]()
SamplerCustomAdvanced = nodes_custom_sampler.NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
VAELoader = NODE_CLASS_MAPPINGS["VAELoader"]()
VAEDecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
EmptyLatentImage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
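
# Each NODE_CLASS_MAPPINGS entry is a ComfyUI node class; instantiating the
# classes up front lets us call their methods directly, replicating a node
# workflow in plain Python without running the ComfyUI server.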

# Load checkpoints and models
with torch.inference_mode():
    checkpoint_path = "models/checkpoints/flux1-dev-fp8-all-in-one.safetensors"
    unet, clip, vae = CheckpointLoaderSimple.load_checkpoint(checkpoint_path)
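
# The all-in-one checkpoint is assumed to bundle the diffusion model, the
# text encoder, and the VAE in a single .safetensors file, which is why one
# load_checkpoint call returns all three components.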

def closestNumber(n, m):
    """Return the multiple of m closest to n (ties go to the larger magnitude)."""
    q = int(n / m)
    n1 = m * q
    # The neighbouring multiple on the other side of n.
    n2 = m * (q + 1) if (n * m) > 0 else m * (q - 1)
    return n1 if abs(n - n1) < abs(n - n2) else n2
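
# Worked example: closestNumber(1000, 16) gives q = 62, n1 = 992, n2 = 1008;
# both are 8 away, so the tie goes to n2 and 1008 is returned.
# closestNumber(1030, 16) returns 1024, which is 6 away versus 10 for 1040.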

@torch.inference_mode()
def generate(positive_prompt, width, height, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip):
    if seed == 0:
        seed = random.randint(0, 18446744073709551615)  # random 64-bit seed (max 2**64 - 1)
    print(f"Seed used: {seed}")

    # Apply the realism LoRA to both the diffusion model and the text encoder
    lora_path = "models/loras/flux_realism_lora.safetensors"
    unet_lora, clip_lora = LoraLoader.load_lora(unet, clip, lora_path, lora_strength_model, lora_strength_clip)

    # Encode the prompt
    cond, pooled = clip_lora.encode_from_tokens(clip_lora.tokenize(positive_prompt), return_pooled=True)
    cond = [[cond, {"pooled_output": pooled}]]
    cond = FluxGuidance.append(cond, guidance)[0]
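    # ComfyUI-style conditioning is a list of [tensor, options] pairs: the
    # pooled CLIP output rides along in the options dict, and FluxGuidance
    # stores the guidance scale there as well (Flux-dev uses an embedded
    # guidance value rather than classic classifier-free guidance).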

    # Generate noise
    noise = RandomNoise.get_noise(seed)[0]

    # Get guider and sampler
    guider = BasicGuider.get_guider(unet_lora, cond)[0]
    sampler = KSamplerSelect.get_sampler(sampler_name)[0]

    # Get scheduling sigmas
    sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]

    # Allocate an empty latent, snapping both dimensions to multiples of 16
    latent_image = EmptyLatentImage.generate(closestNumber(width, 16), closestNumber(height, 16))[0]

    # Sample and decode the image
    sample, sample_denoised = SamplerCustomAdvanced.sample(noise, guider, sampler, sigmas, latent_image)
    decoded = VAEDecode.decode(vae, sample)[0].detach()

    # Clamp to [0, 1] before the uint8 cast so out-of-range values cannot wrap
    decoded = decoded.clamp(0, 1)
    return Image.fromarray(np.array(decoded * 255, dtype=np.uint8)[0])
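
# A minimal sketch of calling generate() directly, without the Gradio UI
# (hypothetical prompt and output path; assumes the checkpoint and LoRA above
# are already downloaded):
#
#   img = generate("a watercolor fox", 1024, 1024, seed=0, steps=20,
#                  sampler_name="euler", scheduler="simple", guidance=3.5,
#                  lora_strength_model=1.0, lora_strength_clip=1.0)
#   img.save("out.png")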

# Define Gradio interface
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            positive_prompt = gr.Textbox(
                lines=3, 
                interactive=True, 
                value="cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black dress with a gold leaf pattern and a white apron eating a slice of an apple pie in the kitchen of an old dark victorian mansion with a bright window and very expensive stuff everywhere", 
                label="Prompt"
            )
            width = gr.Slider(minimum=256, maximum=2048, value=1024, step=16, label="width")
            height = gr.Slider(minimum=256, maximum=2048, value=1024, step=16, label="height")
            seed = gr.Slider(minimum=0, maximum=18446744073709551615, value=0, step=1, label="seed (0=random)")
            steps = gr.Slider(minimum=4, maximum=50, value=20, step=1, label="steps")
            guidance = gr.Slider(minimum=0, maximum=20, value=3.5, step=0.5, label="guidance")
            lora_strength_model = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.1, label="lora_strength_model")
            lora_strength_clip = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.1, label="lora_strength_clip")
            sampler_name = gr.Dropdown(
                ["euler", "heun", "heunpp2", "dpm_2", "lms", "dpmpp_2m", "ipndm", "deis", "ddim", "uni_pc", "uni_pc_bh2"], 
                label="sampler_name", 
                value="euler"
            )
            scheduler = gr.Dropdown(
                ["normal", "sgm_uniform", "simple", "ddim_uniform"], 
                label="scheduler", 
                value="simple"
            )
            generate_button = gr.Button("Generate")
        with gr.Column():
            output_image = gr.Image(label="Generated image", interactive=False)

    generate_button.click(
        fn=generate, 
        inputs=[
            positive_prompt, width, height, seed, steps, 
            sampler_name, scheduler, guidance, 
            lora_strength_model, lora_strength_clip
        ], 
        outputs=output_image
    )
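
# queue() serializes incoming requests so concurrent users don't contend for
# the GPU; share=True asks Gradio to open a temporary public URL.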

demo.queue().launch(inline=False, share=True, debug=True)