# Hugging Face Space: running on Zero (ZeroGPU) hardware.
import os | |
import random | |
import sys | |
from typing import Sequence, Mapping, Any, Union | |
import torch | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
import spaces | |
# Fetch every model file the workflow needs from the Hugging Face Hub.
# Each entry is (repo_id, filename, local_dir); the downloads run at import
# time, in the order listed.
_REQUIRED_MODEL_FILES = (
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/vae"),
    ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/diffusion_models"),
    ("comfyanonymous/flux_text_encoders", "clip_l.safetensors", "models/text_encoders"),
    ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "models/text_encoders"),
    ("kim2091/UltraSharp", "4x-UltraSharp.pth", "models/upscale_models"),
)
for _repo_id, _filename, _local_dir in _REQUIRED_MODEL_FILES:
    hf_hub_download(repo_id=_repo_id, filename=_filename, local_dir=_local_dir)
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Look up ``obj[index]``, falling back to ``obj["result"][index]``.

    The fallback is used only when the direct lookup raises ``KeyError``;
    any other exception (for example ``IndexError`` on a short sequence)
    propagates to the caller unchanged.
    """
    try:
        value = obj[index]
    except KeyError:
        # The requested key is absent: read the entry from the "result" item instead.
        value = obj["result"][index]
    return value
def find_path(name: str, path: Union[str, None] = None) -> Union[str, None]:
    """Search ``path`` and each of its parent directories for an entry called ``name``.

    Args:
        name: File or directory name to look for; it is matched against the
            directory listing of every folder visited.
        path: Directory to start from. Defaults to the current working directory.

    Returns:
        The joined path of the first match found while walking upwards, or
        ``None`` when the top of the directory tree is reached without a match.
    """
    current = os.getcwd() if path is None else path
    while True:
        try:
            entries = os.listdir(current)
        except OSError:
            # A missing or unreadable directory cannot contain the entry;
            # keep walking upwards instead of aborting the whole search.
            entries = ()
        if name in entries:
            path_name = os.path.join(current, name)
            print(f"{name} found: {path_name}")
            return path_name
        parent_directory = os.path.dirname(current)
        if parent_directory == current:
            # dirname() no longer changes the path: nothing left to search.
            return None
        current = parent_directory
def add_comfyui_directory_to_sys_path() -> None:
    """Locate a ``ComfyUI`` directory and append it to ``sys.path``.

    The search is delegated to ``find_path``. Nothing happens when no entry
    named ``ComfyUI`` is found or when the match is not a directory.
    """
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is None or not os.path.isdir(comfyui_path):
        return
    sys.path.append(comfyui_path)
    print(f"'{comfyui_path}' added to sys.path")
def add_extra_model_paths() -> None:
    """Load the optional ``extra_model_paths.yaml`` file, if one can be found.

    ``load_extra_path_config`` is imported from ``main`` and, failing that,
    from ``utils.extra_config``. When neither import works a message is
    printed and the function returns. Otherwise the config file is located
    with ``find_path`` and handed to the loader; a message is printed when
    the file cannot be found.

    Only the imports are guarded, so an ``ImportError`` raised while the
    loader itself runs propagates to the caller instead of being mistaken
    for a missing module.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        try:
            from utils.extra_config import load_extra_path_config
        except ImportError:
            print("Could not import extra config. Continuing without extra model paths.")
            return

    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is None:
        print("Could not find the extra_model_paths config file.")
        return
    load_extra_path_config(extra_model_paths)
# Put the ComfyUI checkout on sys.path before the ComfyUI imports further down.
add_comfyui_directory_to_sys_path()

# Extra model paths are optional: report any failure and carry on.
try:
    add_extra_model_paths()
except Exception as path_error:
    print(f"Warning: Could not load extra model paths: {path_error}")
def import_custom_nodes() -> None:
    """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS.

    A ``server.PromptServer`` and an ``execution.PromptQueue`` are created
    on an asyncio event loop before ``init_extra_nodes`` is called. If that
    call returns a coroutine it is driven to completion on the same loop, so
    both the synchronous and the coroutine form of ``init_extra_nodes`` are
    handled.

    Any failure is reported on stdout and swallowed on purpose: the caller
    continues with whatever nodes are already registered.
    """
    try:
        import asyncio
        import execution
        from nodes import init_extra_nodes
        import server

        # Reuse the current loop only if it is already running; otherwise
        # install a fresh loop for this thread.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or not loop.is_running():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        server_instance = server.PromptServer(loop)
        execution.PromptQueue(server_instance)

        pending = init_extra_nodes()
        if asyncio.iscoroutine(pending):
            if loop.is_running():
                # Cannot block inside a running loop; schedule the work on it.
                # NOTE(review): in this case the nodes are registered only
                # once the loop gets to run the task - confirm that is acceptable.
                loop.create_task(pending)
            else:
                loop.run_until_complete(pending)
    except Exception as e:
        print(f"Warning: Could not initialize custom nodes: {e}")
        print("Continuing with basic ComfyUI nodes only...")
from nodes import NODE_CLASS_MAPPINGS | |
# Pre-load models outside the decorated function for ZeroGPU efficiency
try:
    import_custom_nodes()

    # Text encoders: clip_l + t5xxl, loaded as a "flux" pair.
    dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
    dualcliploader_54 = dualcliploader.load_clip(
        clip_name1="clip_l.safetensors",
        clip_name2="t5xxl_fp16.safetensors",
        type="flux",
        device="default",
    )

    # Upscale model used by the tiled upscaler.
    upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
    upscalemodelloader_44 = upscalemodelloader.load_model(model_name="4x-UltraSharp.pth")

    # VAE.
    vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
    vaeloader_55 = vaeloader.load_vae(vae_name="ae.safetensors")

    # FLUX diffusion model.
    unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
    unetloader_58 = unetloader.load_unet(
        unet_name="flux1-dev.safetensors", weight_dtype="default"
    )

    # Florence-2 captioning model (provided by a custom node).
    downloadandloadflorence2model = NODE_CLASS_MAPPINGS["DownloadAndLoadFlorence2Model"]()
    downloadandloadflorence2model_52 = downloadandloadflorence2model.loadmodel(
        model="microsoft/Florence-2-large", precision="fp16", attention="sdpa"
    )

    # Pre-load models to GPU for efficiency; failure here is only a warning.
    try:
        from comfy import model_management

        model_loaders = [dualcliploader_54, vaeloader_55, unetloader_58, downloadandloadflorence2model_52]
        valid_models = []
        for loader in model_loaders:
            loaded = loader[0]
            # Skip results that are plain dicts, or whose ``patcher`` is a dict.
            if isinstance(loaded, dict):
                continue
            if isinstance(getattr(loaded, 'patcher', None), dict):
                continue
            # Prefer the object's ``patcher`` attribute when it has one.
            valid_models.append(getattr(loaded, 'patcher', loaded))
        model_management.load_models_gpu(valid_models)
        print("Models successfully pre-loaded to GPU")
    except Exception as gpu_error:
        print(f"Warning: Could not pre-load models to GPU: {gpu_error}")

    print("ComfyUI setup completed successfully!")
except Exception as setup_error:
    # Setup problems are fatal: report them, then re-raise.
    print(f"Error during ComfyUI setup: {setup_error}")
    print("Please check that all required custom nodes are installed.")
    raise
# Adjust duration based on your workflow speed
# (60 seconds is the documented default of spaces.GPU).
@spaces.GPU(duration=60)
def enhance_image(image_input, upscale_factor, steps, cfg_scale, denoise_strength, guidance_scale):
    """
    Main function to enhance and upscale images using Florence-2 captioning and FLUX upscaling.

    Args:
        image_input: Either an ``http://`` / ``https://`` URL or a value that the
            ``LoadImage`` node accepts (the UI passes the uploaded file path).
        upscale_factor: Factor forwarded to the upscaler as ``upscale_by``.
        steps: Number of sampling steps for the upscaler.
        cfg_scale: CFG value for the upscaler.
        denoise_strength: Denoise value for the upscaler.
        guidance_scale: Guidance value applied to the positive conditioning.

    Returns:
        A ``(saved_path, generated_caption)`` tuple: the path of the saved
        result under ``output/`` and the caption produced by Florence-2.

    Raises:
        gr.Error: If any step of the pipeline fails; the original exception
            is attached as the cause.
    """
    # random.randint includes both end points, so the upper bound has to be
    # 2**64 - 1 for the seed to stay inside the unsigned 64-bit range.
    # NOTE(review): presumably the nodes pass the seed on to torch - confirm.
    max_seed = 2**64 - 1
    try:
        with torch.inference_mode():
            # Handle different input types (file upload vs URL)
            if isinstance(image_input, str) and image_input.startswith(('http://', 'https://')):
                # Load from URL
                load_image_from_url_mtb = NODE_CLASS_MAPPINGS["Load Image From Url (mtb)"]()
                load_image_result = load_image_from_url_mtb.load(url=image_input)
            else:
                # Load from uploaded file
                loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
                load_image_result = loadimage.load_image(image=image_input)

            # Generate detailed caption using Florence-2
            florence2run = NODE_CLASS_MAPPINGS["Florence2Run"]()
            florence2run_51 = florence2run.encode(
                text_input="",
                task="more_detailed_caption",
                fill_mask=True,
                keep_model_loaded=False,
                max_new_tokens=1024,
                num_beams=3,
                do_sample=True,
                output_mask_select="",
                seed=random.randint(1, max_seed),
                image=get_value_at_index(load_image_result, 0),
                florence2_model=get_value_at_index(downloadandloadflorence2model_52, 0),
            )
            # Element 2 of the Florence-2 result is used as the caption text,
            # both as the positive prompt and as the value shown to the user.
            generated_caption = get_value_at_index(florence2run_51, 2)

            # Encode the generated caption
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            cliptextencode_6 = cliptextencode.encode(
                text=generated_caption,
                clip=get_value_at_index(dualcliploader_54, 0),
            )

            # Encode empty negative prompt
            cliptextencode_42 = cliptextencode.encode(
                text="", clip=get_value_at_index(dualcliploader_54, 0)
            )

            # Set up upscale factor
            primitivefloat = NODE_CLASS_MAPPINGS["PrimitiveFloat"]()
            primitivefloat_60 = primitivefloat.execute(value=upscale_factor)

            # Apply FLUX guidance
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            fluxguidance_26 = fluxguidance.append(
                guidance=guidance_scale,
                conditioning=get_value_at_index(cliptextencode_6, 0),
            )

            # Perform ultimate upscaling
            ultimatesdupscale = NODE_CLASS_MAPPINGS["UltimateSDUpscale"]()
            ultimatesdupscale_50 = ultimatesdupscale.upscale(
                upscale_by=get_value_at_index(primitivefloat_60, 0),
                seed=random.randint(1, max_seed),
                steps=steps,
                cfg=cfg_scale,
                sampler_name="euler",
                scheduler="normal",
                denoise=denoise_strength,
                mode_type="Linear",
                tile_width=1024,
                tile_height=1024,
                mask_blur=8,
                tile_padding=32,
                seam_fix_mode="None",
                seam_fix_denoise=1,
                seam_fix_width=64,
                seam_fix_mask_blur=8,
                seam_fix_padding=16,
                force_uniform_tiles=True,
                tiled_decode=False,
                image=get_value_at_index(load_image_result, 0),
                model=get_value_at_index(unetloader_58, 0),
                positive=get_value_at_index(fluxguidance_26, 0),
                negative=get_value_at_index(cliptextencode_42, 0),
                vae=get_value_at_index(vaeloader_55, 0),
                upscale_model=get_value_at_index(upscalemodelloader_44, 0),
            )

            # Save the result
            saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
            saveimage_43 = saveimage.save_images(
                filename_prefix="enhanced_image",
                images=get_value_at_index(ultimatesdupscale_50, 0),
            )

            # Return the path to the saved image together with the caption
            saved_path = f"output/{saveimage_43['ui']['images'][0]['filename']}"
            return saved_path, generated_caption
    except Exception as e:
        print(f"Error in enhance_image: {str(e)}")
        # Chain the cause so the original traceback is kept in the logs.
        raise gr.Error(f"Enhancement failed: {str(e)}") from e
# Create the Gradio interface
def create_interface():
    """Build the Gradio Blocks app for the image enhancer and return it.

    The left column holds the inputs (image upload or URL, plus the
    enhancement sliders); the right column shows the enhanced image and the
    generated caption. Clicking the button runs ``enhance_image`` on the
    chosen input.

    NOTE(review): the emoji in the labels replace mis-decoded characters that
    were present in this file; they are best guesses from context - confirm
    against the intended design.

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    with gr.Blocks(
        title="🚀 AI Image Enhancer - Florence-2 + FLUX",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .main-header {
            text-align: center;
            margin-bottom: 2rem;
        }
        .result-gallery {
            min-height: 400px;
        }
        """
    ) as app:
        gr.HTML("""
        <div class="main-header">
            <h1>🎨 AI Image Enhancer</h1>
            <p>Upload an image or provide a URL to enhance it using Florence-2 captioning and FLUX upscaling</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("<h3>📤 Input Settings</h3>")
                with gr.Tabs():
                    with gr.TabItem("📁 Upload Image"):
                        image_upload = gr.Image(
                            label="Upload Image",
                            type="filepath",
                            height=300
                        )
                    with gr.TabItem("🔗 Image URL"):
                        image_url = gr.Textbox(
                            label="Image URL",
                            placeholder="https://example.com/image.jpg",
                            value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
                        )

                gr.HTML("<h3>⚙️ Enhancement Settings</h3>")
                upscale_factor = gr.Slider(
                    minimum=1.0,
                    maximum=4.0,
                    value=2.0,
                    step=0.5,
                    label="Upscale Factor",
                    info="How much to upscale the image"
                )
                steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=25,
                    step=5,
                    label="Steps",
                    info="Number of denoising steps"
                )
                cfg_scale = gr.Slider(
                    minimum=0.5,
                    maximum=10.0,
                    value=1.0,
                    step=0.5,
                    label="CFG Scale",
                    info="Classifier-free guidance scale"
                )
                denoise_strength = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Denoise Strength",
                    info="How much to denoise the image"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.5,
                    step=0.5,
                    label="Guidance Scale",
                    info="FLUX guidance strength"
                )
                enhance_btn = gr.Button(
                    "🚀 Enhance Image",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                gr.HTML("<h3>📊 Results</h3>")
                output_image = gr.Image(
                    label="Enhanced Image",
                    type="filepath",
                    height=400,
                    interactive=False
                )
                generated_caption = gr.Textbox(
                    label="Generated Caption",
                    placeholder="The AI-generated caption will appear here...",
                    lines=3,
                    interactive=False
                )
                gr.HTML("""
                <div style="margin-top: 1rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
                    <h4>💡 How it works:</h4>
                    <ol>
                        <li>Florence-2 analyzes your image and generates a detailed caption</li>
                        <li>FLUX uses this caption to guide the upscaling process</li>
                        <li>The result is an enhanced, higher-resolution image</li>
                    </ol>
                </div>
                """)

        # Event handlers
        def process_image(img_upload, img_url, upscale_f, steps_val, cfg_val, denoise_val, guidance_val):
            """Pick the input source (an upload wins over the URL) and run the enhancer."""
            if img_upload is not None:
                image_input = img_upload
            else:
                # Strip stray whitespace so a pasted URL is still recognised as
                # a URL and a blank field counts as "no input".
                image_input = img_url.strip() if isinstance(img_url, str) else img_url
            if not image_input:
                raise gr.Error("Please provide an image (upload or URL)")
            return enhance_image(image_input, upscale_f, steps_val, cfg_val, denoise_val, guidance_val)

        enhance_btn.click(
            fn=process_image,
            inputs=[
                image_upload,
                image_url,
                upscale_factor,
                steps,
                cfg_scale,
                denoise_strength,
                guidance_scale
            ],
            outputs=[output_image, generated_caption]
        )

        # Example inputs
        gr.Examples(
            examples=[
                [None, "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg", 2.0, 25, 1.0, 0.3, 3.5],
                [None, "https://picsum.photos/512/512", 2.0, 20, 1.5, 0.4, 4.0],
            ],
            inputs=[
                image_upload,
                image_url,
                upscale_factor,
                steps,
                cfg_scale,
                denoise_strength,
                guidance_scale
            ]
        )
    return app
if __name__ == "__main__":
    # Serve the UI on all interfaces, port 7860, and request a share link.
    launch_options = {"share": True, "server_name": "0.0.0.0", "server_port": 7860}
    app = create_interface()
    app.launch(**launch_options)