mobenta committed on
Commit 7473c7e · verified · Parent: 6d68a8b

Update app.py

Files changed (1)
  1. app.py +96 -110
app.py CHANGED
@@ -1,127 +1,113 @@
-import os
-import sys
-import random
 import torch
-import numpy as np
-from PIL import Image
 import gradio as gr

-# Check and add the ComfyUI repository path to sys.path
-repo_path = './ComfyUI/totoro_extras'
-print(f"Checking for repository path: {repo_path}")
-if not os.path.exists(repo_path):
-    raise FileNotFoundError(f"Repository path '{repo_path}' not found. Make sure the ComfyUI repository is cloned correctly.")
-sys.path.append(repo_path)
-print(f"Repository path added to sys.path: {repo_path}")

-# Import nodes and custom modules
-from nodes import NODE_CLASS_MAPPINGS
-from totoro_extras import nodes_custom_sampler, nodes_flux

-# Initialize necessary components from the nodes
-CheckpointLoaderSimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
-LoraLoader = NODE_CLASS_MAPPINGS["LoraLoader"]()
-FluxGuidance = nodes_flux.NODE_CLASS_MAPPINGS["FluxGuidance"]()
-RandomNoise = nodes_custom_sampler.NODE_CLASS_MAPPINGS["RandomNoise"]()
-BasicGuider = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicGuider"]()
-KSamplerSelect = nodes_custom_sampler.NODE_CLASS_MAPPINGS["KSamplerSelect"]()
-BasicScheduler = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicScheduler"]()
-SamplerCustomAdvanced = nodes_custom_sampler.NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
-VAELoader = NODE_CLASS_MAPPINGS["VAELoader"]()
-VAEDecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
-EmptyLatentImage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()

-# Load checkpoints and models
-with torch.inference_mode():
-    checkpoint_path = "models/checkpoints/flux1-dev-fp8-all-in-one.safetensors"
-    unet, clip, vae = CheckpointLoaderSimple.load_checkpoint(checkpoint_path)

-def closestNumber(n, m):
-    q = int(n / m)
-    n1 = m * q
-    if (n * m) > 0:
-        n2 = m * (q + 1)
-    else:
-        n2 = m * (q - 1)
-    if abs(n - n1) < abs(n - n2):
-        return n1
-    return n2

-@torch.inference_mode()
-def generate(positive_prompt, width, height, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip):
-    global unet, clip
     if seed == 0:
-        seed = random.randint(0, 18446744073709551615)
-    print(f"Seed used: {seed}")
-
-    # Load LoRA models
-    lora_path = "models/loras/flux_realism_lora.safetensors"
-    unet_lora, clip_lora = LoraLoader.load_lora(unet, clip, lora_path, lora_strength_model, lora_strength_clip)
-
-    # Encode the prompt
-    cond, pooled = clip_lora.encode_from_tokens(clip_lora.tokenize(positive_prompt), return_pooled=True)
-    cond = [[cond, {"pooled_output": pooled}]]
-    cond = FluxGuidance.append(cond, guidance)[0]
-
-    # Generate noise
-    noise = RandomNoise.get_noise(seed)[0]
-
-    # Get guider and sampler
-    guider = BasicGuider.get_guider(unet_lora, cond)[0]
-    sampler = KSamplerSelect.get_sampler(sampler_name)[0]
-
-    # Get scheduling sigmas
-    sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]
-
-    # Generate latent image
-    latent_image = EmptyLatentImage.generate(closestNumber(width, 16), closestNumber(height, 16))[0]
-
-    # Sample and decode the image
-    sample, sample_denoised = SamplerCustomAdvanced.sample(noise, guider, sampler, sigmas, latent_image)
-    decoded = VAEDecode.decode(vae, sample)[0].detach()
-
-    # Convert to image and return
-    return Image.fromarray(np.array(decoded * 255, dtype=np.uint8)[0])
-
-# Define Gradio interface
-with gr.Blocks(analytics_enabled=False) as demo:
     with gr.Row():
         with gr.Column():
-            positive_prompt = gr.Textbox(
-                lines=3,
-                interactive=True,
-                value="cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black dress with a gold leaf pattern and a white apron eating a slice of an apple pie in the kitchen of an old dark victorian mansion with a bright window and very expensive stuff everywhere",
-                label="Prompt"
             )
-            width = gr.Slider(minimum=256, maximum=2048, value=1024, step=16, label="width")
-            height = gr.Slider(minimum=256, maximum=2048, value=1024, step=16, label="height")
-            seed = gr.Slider(minimum=0, maximum=18446744073709551615, value=0, step=1, label="seed (0=random)")
-            steps = gr.Slider(minimum=4, maximum=50, value=20, step=1, label="steps")
-            guidance = gr.Slider(minimum=0, maximum=20, value=3.5, step=0.5, label="guidance")
-            lora_strength_model = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.1, label="lora_strength_model")
-            lora_strength_clip = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.1, label="lora_strength_clip")
-            sampler_name = gr.Dropdown(
-                ["euler", "heun", "heunpp2", "dpm_2", "lms", "dpmpp_2m", "ipndm", "deis", "ddim", "uni_pc", "uni_pc_bh2"],
-                label="sampler_name",
-                value="euler"
             )
-            scheduler = gr.Dropdown(
-                ["normal", "sgm_uniform", "simple", "ddim_uniform"],
-                label="scheduler",
-                value="simple"
-            )
-            generate_button = gr.Button("Generate")
         with gr.Column():
-            output_image = gr.Image(label="Generated image", interactive=False)

-    generate_button.click(
-        fn=generate,
-        inputs=[
-            positive_prompt, width, height, seed, steps,
-            sampler_name, scheduler, guidance,
-            lora_strength_model, lora_strength_clip
-        ],
-        outputs=output_image
-    )

-demo.queue().launch(inline=False, share=True, debug=True)
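Note: the removed closestNumber(n, m) helper returns the multiple of m nearest to n; the old script used it to snap the requested width and height to multiples of 16 before allocating the empty latent. Illustrative checks of its behavior (not part of the commit):

assert closestNumber(1024, 16) == 1024  # already a multiple of 16
assert closestNumber(1030, 16) == 1024  # 6 below vs. 10 above, so it rounds down
assert closestNumber(1000, 16) == 1008  # on a tie, the larger multiple wins (for positive n)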
 
 
 
 
 import torch
+from diffusers import FluxPipeline
 import gradio as gr
+import random
+import numpy as np
+import os
+import spaces
+
+# Check for GPU availability
+if torch.cuda.is_available():
+    device = "cuda"
+    print("Using GPU")
+else:
+    device = "cpu"
+    print("Using CPU")
+
+# Set environment variables
+HF_TOKEN = os.getenv("HF_TOKEN")  # Make sure to set this in your environment
+
+MAX_SEED = np.iinfo(np.int32).max
+CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
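Note: MAX_SEED is the maximum of a signed 32-bit integer, a common upper bound for UI-exposed seeds. A quick check of the value (illustrative, not part of the commit):

import numpy as np
assert np.iinfo(np.int32).max == 2147483647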
+# Initialize the pipeline and download the model
+pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+pipe.to(device)
+
+# Enable memory optimizations
+pipe.enable_attention_slicing()
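Note: enable_attention_slicing() computes attention in slices to lower peak VRAM at some speed cost. If the model still does not fit on the GPU, diffusers also provides CPU offloading; a minimal sketch, assuming accelerate is installed (not part of the commit):

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # moves each submodule to the GPU only while it runs; use instead of pipe.to(device)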
+# Define the image generation function
+@spaces.GPU(duration=180)
+def generate_image(prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt, progress=gr.Progress(track_tqdm=True)):
     if seed == 0:
+        seed = random.randint(1, MAX_SEED)
+
+    generator = torch.Generator().manual_seed(seed)
+
+    with torch.inference_mode():
+        output = pipe(
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            height=height,
+            width=width,
+            guidance_scale=guidance_scale,
+            generator=generator,
+            num_images_per_prompt=num_images_per_prompt
+        ).images
+
+    return output
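Note: a direct smoke test of the handler (illustrative, not part of the commit; the spaces.GPU decorator is documented to have no effect outside a Hugging Face Space, and the small size and step count keep the run short):

images = generate_image(
    prompt="a cat holding a sign that says hello world",
    num_inference_steps=4,
    height=512,
    width=512,
    guidance_scale=3.5,
    seed=42,
    num_images_per_prompt=1,
)
images[0].save("smoke_test.png")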
+# Create the Gradio interface
+examples = [
+    ["A cat holding a sign that says hello world"],
+    ["a tiny astronaut hatching from an egg on the moon"],
+    ["An astronaut on Mars in a futuristic cyborg suit."],
+]
+
+css = '''
+.gradio-container{max-width: 1000px !important}
+h1{text-align:center}
+'''
+
+with gr.Blocks(css=css) as demo:
     with gr.Row():
         with gr.Column():
+            gr.HTML(
+                """
+                <h1 style='text-align: center'>
+                FLUX.1-dev
+                </h1>
+                """
             )
+            gr.HTML(
+                """
+                Made by <a href='https://linktr.ee/Nick088' target='_blank'>Nick088</a>
+                <br> <a href="https://discord.gg/osai"> <img src="https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge" alt="Discord"> </a>
+                """
             )
+    with gr.Group():
         with gr.Column():
+            prompt = gr.Textbox(label="Prompt", info="Describe the image you want", placeholder="A cat...")
+            run_button = gr.Button("Run")
+            result = gr.Gallery(label="Generated AI Images", elem_id="gallery")
+            with gr.Accordion("Advanced options", open=False):
+                with gr.Row():
+                    num_inference_steps = gr.Slider(label="Number of Inference Steps", info="The number of denoising steps. More steps usually lead to a higher-quality image at the cost of slower inference.", minimum=1, maximum=50, value=25, step=1)
+                    guidance_scale = gr.Slider(label="Guidance Scale", info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.", minimum=0.0, maximum=7.0, value=3.5, step=0.1)
+                with gr.Row():
+                    width = gr.Slider(label="Width", info="Width of the image", minimum=256, maximum=1024, step=32, value=1024)
+                    height = gr.Slider(label="Height", info="Height of the image", minimum=256, maximum=1024, step=32, value=1024)
+                with gr.Row():
+                    seed = gr.Slider(value=42, minimum=0, maximum=MAX_SEED, step=1, label="Seed", info="A starting point for the generation process; set 0 for a random seed")
+                    num_images_per_prompt = gr.Slider(label="Images Per Prompt", info="Number of images to generate with these settings", minimum=1, maximum=4, step=1, value=2)
+
+    gr.Examples(
+        examples=examples,
+        fn=generate_image,
+        inputs=[prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt],
+        outputs=[result],
+        cache_examples=CACHE_EXAMPLES
+    )
+
+    gr.on(
+        triggers=[
+            prompt.submit,
+            run_button.click,
+        ],
+        fn=generate_image,
+        inputs=[prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt],
+        outputs=[result],
     )
+
+demo.queue().launch(share=False)
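Note: the imports above imply roughly this dependency set (a hypothetical requirements.txt, not part of the commit; FLUX pipelines typically also pull in transformers and sentencepiece for the text encoders):

torch
diffusers
transformers
sentencepiece
accelerate
gradio
numpy
spaces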