# Chroma / app.py
# Author: gokaygokay
# Last commit: "Update app.py" (f064a5b, verified)
# File size: 6.61 kB
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
import spaces
from comfy import model_management
# Download the three checkpoints the workflow needs into the local ComfyUI
# model folders. The downloads run in the listed order (text encoder, VAE,
# diffusion model), and each call's result is bound to t5_path, vae_path and
# unet_path respectively.
t5_path, vae_path, unet_path = [
    hf_hub_download(repo_id=source_repo, filename=weights_file, local_dir=target_dir)
    for source_repo, weights_file, target_dir in (
        ("comfyanonymous/flux_text_encoders", "t5xxl_fp8_e4m3fn.safetensors", "models/text_encoders/"),
        ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/vae"),
        ("lodestones/Chroma", "chroma-unlocked-v31.safetensors", "models/unet"),
    )
]
# Import the workflow functions
from my_workflow import (
get_value_at_index,
add_comfyui_directory_to_sys_path,
add_extra_model_paths,
import_custom_nodes,
NODE_CLASS_MAPPINGS,
CLIPTextEncode,
CLIPLoader,
VAEDecode,
UNETLoader,
VAELoader,
SaveImage,
)
# Bootstrap ComfyUI. Order matters: the ComfyUI directory goes on sys.path
# first, then extra model paths are registered, then custom nodes are imported
# (presumably this is what fills NODE_CLASS_MAPPINGS, which is read right after;
# confirm in my_workflow).
for _bootstrap_step in (
    add_comfyui_directory_to_sys_path,
    add_extra_model_paths,
    import_custom_nodes,
):
    _bootstrap_step()
del _bootstrap_step
# Instantiate every workflow node once at import time so that each request
# reuses the same node objects. The right-hand tuple is evaluated left to right,
# so construction order matches the listed names.
(
    randomnoise,
    emptysd3latentimage,
    ksamplerselect,
    cliploader,
    t5tokenizeroptions,
    cliptextencode,
    unetloader,
    vaeloader,
    cfgguider,
    basicscheduler,
    samplercustomadvanced,
    vaedecode,
    saveimage,
) = (
    NODE_CLASS_MAPPINGS["RandomNoise"](),
    NODE_CLASS_MAPPINGS["EmptySD3LatentImage"](),
    NODE_CLASS_MAPPINGS["KSamplerSelect"](),
    CLIPLoader(),
    NODE_CLASS_MAPPINGS["T5TokenizerOptions"](),
    CLIPTextEncode(),
    UNETLoader(),
    VAELoader(),
    NODE_CLASS_MAPPINGS["CFGGuider"](),
    NODE_CLASS_MAPPINGS["BasicScheduler"](),
    NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"](),
    VAEDecode(),
    SaveImage(),
)
# Load model weights once at import time; generate_image reuses these results.
# Text encoder: the T5-XXL fp8 checkpoint, loaded with the "chroma" CLIP type.
cliploader_78 = cliploader.load_clip(
    clip_name="t5xxl_fp8_e4m3fn.safetensors",
    type="chroma",
    device="default",
)
# Tokenizer options are applied to the loaded CLIP object (output slot 0).
t5tokenizeroptions_82 = t5tokenizeroptions.set_options(
    clip=get_value_at_index(cliploader_78, 0),
    min_length=0,
    min_padding=1,
)
# Chroma diffusion model, loaded with the requested weight dtype fp8_e4m3fn.
unetloader_76 = unetloader.load_unet(
    weight_dtype="fp8_e4m3fn",
    unet_name="chroma-unlocked-v31.safetensors",
)
# Autoencoder used at the end of the graph to decode latents.
vaeloader_80 = vaeloader.load_vae(vae_name="ae.safetensors")
def _gpu_loadable(loader_outputs):
    """Select the objects from loader results that can be passed to load_models_gpu.

    Element 0 of each loader result is the loaded object. Its ``patcher``
    attribute is used when it has one; otherwise the object itself is used.
    An entry is skipped when the object, or its ``patcher``, is a plain dict.
    """
    loadable = []
    for outputs in loader_outputs:
        candidate = outputs[0]
        if isinstance(candidate, dict):
            continue
        if isinstance(getattr(candidate, "patcher", None), dict):
            continue
        loadable.append(getattr(candidate, "patcher", candidate))
    return loadable


# Every loader whose result came from a safetensors file.
model_loaders = [cliploader_78, unetloader_76, vaeloader_80]
valid_models = _gpu_loadable(model_loaders)
# Hand the selected models to ComfyUI's load_models_gpu before any request runs.
model_management.load_models_gpu(valid_models)
# Largest seed torch.manual_seed accepts (unsigned 64-bit range).
# random.randint includes its upper bound, so the old bound of 2**64 could
# produce an out-of-range seed.
_MAX_SEED = 2**64 - 1


def _resolve_seed(seed):
    """Return the integer seed to use for one generation.

    ``-1`` (the UI's "random" sentinel) or ``None`` (an emptied Number field)
    draws a fresh seed in ``[1, _MAX_SEED]``. Any other value is coerced to
    int, because Gradio's Number component can deliver floats.
    """
    if seed is None or seed == -1:
        return random.randint(1, _MAX_SEED)
    return int(seed)


def _first_image_as_pil(images):
    """Convert the first image of a decoded batch into a ``PIL.Image``.

    ``gr.Image`` cannot render a raw torch tensor, so the decoder output is
    converted here before being returned to Gradio.
    NOTE(review): assumes the ComfyUI IMAGE layout, (batch, height, width,
    channels) floats in [0, 1]; confirm against VAEDecode's output.
    """
    pixels = (images[0] * 255.0).clamp(0, 255).to(torch.uint8).cpu().numpy()
    return Image.fromarray(pixels)


@spaces.GPU
def generate_image(prompt, negative_prompt, width, height, steps, cfg, seed):
    """Run the Chroma text-to-image graph once and return the result.

    Args:
        prompt: Text for the positive conditioning.
        negative_prompt: Text for the negative conditioning.
        width: Latent width in pixels, passed to EmptySD3LatentImage.
        height: Latent height in pixels, passed to EmptySD3LatentImage.
        steps: Number of sampling steps for the "beta" scheduler.
        cfg: Guidance scale passed to CFGGuider.
        seed: Noise seed. -1 or None picks a random seed.

    Returns:
        PIL.Image.Image: The first image of the decoded batch.
    """
    seed = _resolve_seed(seed)
    # Kept from the original: Python's global RNG is seeded as well.
    random.seed(seed)

    with torch.inference_mode():
        randomnoise_68 = randomnoise.get_noise(noise_seed=seed)
        # Slider values may arrive as floats; the latent shape and step count must be ints.
        emptysd3latentimage_69 = emptysd3latentimage.generate(
            width=int(width), height=int(height), batch_size=1
        )
        ksamplerselect_72 = ksamplerselect.get_sampler(sampler_name="euler")

        # Positive and negative prompts share the tokenizer-configured T5 encoder.
        text_encoder = get_value_at_index(t5tokenizeroptions_82, 0)
        cliptextencode_74 = cliptextencode.encode(text=prompt, clip=text_encoder)
        cliptextencode_75 = cliptextencode.encode(text=negative_prompt, clip=text_encoder)

        diffusion_model = get_value_at_index(unetloader_76, 0)
        cfgguider_73 = cfgguider.get_guider(
            cfg=cfg,
            model=diffusion_model,
            positive=get_value_at_index(cliptextencode_74, 0),
            negative=get_value_at_index(cliptextencode_75, 0),
        )
        basicscheduler_84 = basicscheduler.get_sigmas(
            scheduler="beta",
            steps=int(steps),
            denoise=1,
            model=diffusion_model,
        )
        samplercustomadvanced_67 = samplercustomadvanced.sample(
            noise=get_value_at_index(randomnoise_68, 0),
            guider=get_value_at_index(cfgguider_73, 0),
            sampler=get_value_at_index(ksamplerselect_72, 0),
            sigmas=get_value_at_index(basicscheduler_84, 0),
            latent_image=get_value_at_index(emptysd3latentimage_69, 0),
        )
        vaedecode_79 = vaedecode.decode(
            samples=get_value_at_index(samplercustomadvanced_67, 0),
            vae=get_value_at_index(vaeloader_80, 0),
        )
        # Return the image directly instead of saving to disk, converted to a
        # type gr.Image can display.
        return _first_image_as_pil(get_value_at_index(vaedecode_79, 0))
# Gradio interface: prompt and parameter controls in the left column, the
# generated image in the right column.
with gr.Blocks() as app:
    gr.Markdown("# Chroma Image Generator")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=3,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="Enter negative prompt here...",
                value="low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors",
                lines=2,
            )
            with gr.Row():
                width = gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Width")
                height = gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Height")
            with gr.Row():
                steps = gr.Slider(minimum=1, maximum=50, value=26, step=1, label="Steps")
                cfg = gr.Slider(minimum=1, maximum=20, value=4, step=0.5, label="CFG Scale")
            # A seed is an integer. precision=0 makes Gradio round the value and
            # deliver it as an int instead of a float.
            seed = gr.Number(
                value=-1,
                label="Seed (-1 for random)",
                precision=0,
            )
            generate_btn = gr.Button("Generate")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    # Input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, width, height, steps, cfg, seed],
        outputs=[output_image],
    )

if __name__ == "__main__":
    app.launch(share=True)