# fluxhdupscaler / app.py
# Hugging Face file-view header: last commit "Update app.py" (a49f337, verified) by comrender; file size 10.6 kB.
import os
import random
import warnings
import gc
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxImg2ImgPipeline
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import requests
# ESRGAN imports
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import img2tensor, tensor2img
# Custom CSS handed to gr.Blocks: centers #col-container and caps its width at
# 800px, and centers the .main-header banner with a 2rem gap below it.
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""
# Hugging Face access token taken from the environment (None when HF_TOKEN is unset)
huggingface_token = os.environ.get("HF_TOKEN")

# Fetch the FLUX.1-dev snapshot into the local "FLUX.1-dev" folder, skipping
# markdown files and .gitattributes
print("πŸ“₯ Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    local_dir="FLUX.1-dev",
    ignore_patterns=["*.md", "*.gitattributes"],
    token=huggingface_token,
)

# Build the img2img pipeline from the snapshot in bfloat16; no device is chosen
# here, the pipeline is moved to CUDA later inside enhance_image
print("πŸ“₯ Loading FLUX Img2Img pipeline...")
pipe = FluxImg2ImgPipeline.from_pretrained(
    model_path,
    use_safetensors=True,
    torch_dtype=torch.bfloat16,
)

# Memory optimizations: VAE tiling and slicing are switched on through both the
# pipeline-level and the VAE-level methods.
# NOTE(review): the two pairs look redundant with each other - confirm against
# the installed diffusers version before removing either.
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
# Download and load ESRGAN 4x-UltraSharp model
print("📥 Loading ESRGAN 4x-UltraSharp...")
esrgan_path = "4x-UltraSharp.pth"
if not os.path.exists(esrgan_path):
    print("Downloading ESRGAN model...")
    url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
    # Write to a side file and rename only after the transfer succeeded, so an
    # interrupted or failed download never leaves a corrupt file at
    # `esrgan_path` (which the existence check above would then trust forever).
    partial_path = esrgan_path + ".part"
    # (connect timeout, read timeout) in seconds: fail instead of hanging startup.
    with requests.get(url, stream=True, timeout=(10, 300)) as response:
        # Abort on HTTP errors rather than saving an error page as the weights.
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(partial_path, esrgan_path)
# Build the RRDB network (3-channel in/out, 23 blocks, 4x scale) that receives
# the downloaded weights
esrgan_model = RRDBNet(
    num_in_ch=3, num_out_ch=3, scale=4,
    num_feat=64, num_block=23, num_grow_ch=32,
)

# NOTE(review): torch.load unpickles a file fetched from the network - consider
# weights_only=True if the checkpoint format allows it.
state_dict = torch.load(esrgan_path, map_location='cpu')
# The weights may be nested under "params_ema" (checked first) or "params";
# otherwise the loaded object is used as the state dict directly.
state_dict = (
    state_dict['params_ema'] if 'params_ema' in state_dict
    else state_dict['params'] if 'params' in state_dict
    else state_dict
)
esrgan_model.load_state_dict(state_dict)
# Inference only: switch the network to eval mode
esrgan_model.eval()
print("βœ… All models loaded successfully!")
# Largest seed value (inclusive) offered by the seed slider and drawn at random
MAX_SEED = 1_000_000
# Longest side, in pixels, an input image may have before upscaling
MAX_INPUT_SIZE = 512
def make_multiple_16(n):
    """Round *n* up to the next multiple of 16 (FLUX size requirement).

    Values that are already a multiple of 16 are returned unchanged.
    """
    padded = n + 15
    return padded - padded % 16
def truncate_prompt(prompt, max_tokens=75):
    """Shorten *prompt* so it stays under the CLIP token limit (77 tokens).

    Tokens are not counted exactly. The budget is approximated by character
    count at 250 characters per 75 tokens, so the default ``max_tokens=75``
    keeps at most 250 characters.

    Args:
        prompt: Prompt text; ``None`` or an empty string yields ``""``.
        max_tokens: Approximate token budget used to derive the character
            limit. A budget of zero or less yields ``""``.

    Returns:
        The prompt unchanged when it fits, otherwise its first characters up
        to the limit followed by ``"..."``.
    """
    if not prompt:
        return ""
    # Integer arithmetic for "250 chars per 75 tokens" (75 * 10 // 3 == 250).
    max_chars = int(max_tokens) * 10 // 3
    if max_chars <= 0:
        return ""
    if len(prompt) > max_chars:
        return prompt[:max_chars] + "..."
    return prompt
def prepare_image(image, max_size=MAX_INPUT_SIZE):
    """Return an RGB version of *image* whose sides do not exceed *max_size*.

    The caller's image object is never modified: any conversion or resize is
    applied to a new image.

    Args:
        image: PIL image to prepare.
        max_size: Maximum allowed width and height in pixels.

    Returns:
        The original object when it is already RGB and small enough,
        otherwise a converted and/or downscaled image (aspect ratio kept).
    """
    # The ESRGAN network is built with three input channels, so other modes
    # (RGBA, L, P, ...) are converted up front.
    prepared = image if image.mode == "RGB" else image.convert("RGB")

    width, height = prepared.size
    if width > max_size or height > max_size:
        # thumbnail() resizes in place; work on a copy unless convert()
        # already produced a fresh object.
        if prepared is image:
            prepared = image.copy()
        prepared.thumbnail((max_size, max_size), Image.LANCZOS)
    return prepared
def esrgan_upscale(image):
    """Run *image* through the module-level ESRGAN model on the CPU.

    Takes a PIL image and returns the network output as a new PIL image.
    """
    # Scale pixel values to [0, 1] floats and convert to a tensor
    pixels = np.array(image).astype(np.float32) / 255.
    tensor_in = img2tensor(pixels, bgr2rgb=False, float32=True)

    # Inference without gradient tracking, on a batch of one
    with torch.no_grad():
        batch = tensor_in.unsqueeze(0).cpu()
        upscaled = esrgan_model(batch)

    # Drop the batch dimension and map [0, 1] back to an image array
    restored = tensor2img(upscaled.squeeze(0), rgb2bgr=False, min_max=(0, 1))
    return Image.fromarray(restored)
@spaces.GPU(duration=60)  # 60 seconds should be enough
def enhance_image(
    input_image,
    prompt,
    seed,
    randomize_seed,
    num_inference_steps,
    denoising_strength,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale an image with ESRGAN, then refine the result with FLUX img2img.

    Args:
        input_image: PIL image to enhance, or ``None`` when nothing was uploaded.
        prompt: Text guiding the FLUX pass; missing or blank text falls back
            to ``"high quality, detailed"``.
        seed: Seed for the FLUX noise generator; replaced by a random value
            in ``[0, MAX_SEED]`` when ``randomize_seed`` is true.
        randomize_seed: Whether to draw a fresh seed for this run.
        num_inference_steps: Number of FLUX inference steps.
        denoising_strength: ``strength`` value passed to the FLUX pipeline.
        progress: Gradio progress tracker (``track_tqdm=True``); not
            referenced in the body.

    Returns:
        ``([before, after], seed)``: the prepared input resized to the output
        size, the enhanced image, and the seed that was actually used.

    Raises:
        gr.Error: If no image was supplied or any processing step fails.
    """
    if input_image is None:
        raise gr.Error("Please upload an image")

    # Clear memory
    torch.cuda.empty_cache()
    gc.collect()

    try:
        # Randomize seed if needed
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # UI number widgets may hand over floats; manual_seed() and the step
        # count need real integers.
        seed = int(seed)
        num_inference_steps = int(num_inference_steps)

        # Fall back to the default prompt for missing *and* whitespace-only text
        prompt = truncate_prompt((prompt or "").strip() or "high quality, detailed")

        # Prepare input image
        input_image = prepare_image(input_image)

        # Step 1: ESRGAN upscale (4x) on the GPU
        gr.Info("🔍 Upscaling with ESRGAN 4x...")
        with torch.no_grad():
            # Move ESRGAN to GPU for faster processing
            esrgan_model.to("cuda")
            # Convert image for ESRGAN
            img_np = np.array(input_image).astype(np.float32) / 255.
            img_tensor = img2tensor(img_np, bgr2rgb=False, float32=True)
            img_tensor = img_tensor.unsqueeze(0).to("cuda")
            # Upscale
            output_tensor = esrgan_model(img_tensor)
            # Convert back to PIL
            output_np = tensor2img(
                output_tensor.squeeze(0).cpu(), rgb2bgr=False, min_max=(0, 1)
            )
            upscaled_image = Image.fromarray(output_np)
        # Free the ESRGAN weights from GPU memory before FLUX is moved over
        esrgan_model.to("cpu")
        torch.cuda.empty_cache()

        # Ensure dimensions are multiples of 16 for FLUX
        w, h = upscaled_image.size
        new_w = make_multiple_16(w)
        new_h = make_multiple_16(h)
        was_padded = new_w != w or new_h != h
        if was_padded:
            # Pad on the right/bottom; the padding is cropped away again below
            padded = Image.new('RGB', (new_w, new_h))
            padded.paste(upscaled_image, (0, 0))
            upscaled_image = padded

        # Step 2: FLUX enhancement
        gr.Info("🎨 Enhancing with FLUX...")
        # Move pipeline to GPU
        pipe.to("cuda")
        # Generate with FLUX
        generator = torch.Generator(device="cuda").manual_seed(seed)
        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                image=upscaled_image,
                strength=denoising_strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=1.0,  # Fixed at 1.0 for FLUX dev
                height=new_h,
                width=new_w,
                generator=generator,
            ).images[0]

        # Crop back if we padded
        if was_padded:
            result = result.crop((0, 0, w, h))

        # Prepare images for slider (before/after)
        input_resized = input_image.resize(result.size, Image.LANCZOS)
        gr.Info("✅ Enhancement complete!")
        return [input_resized, result], seed
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback
        raise gr.Error(f"Enhancement failed: {str(e)}") from e
    finally:
        # Runs on success and on failure: return both models to the CPU and
        # release cached GPU memory.
        pipe.to("cpu")
        esrgan_model.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
# Create Gradio interface
# NOTE(review): the nesting of the `with` blocks below is not visible in this
# copy (indentation was lost) - confirm the layout against the running app.
with gr.Blocks(css=css) as demo:
# Page header: title, one-line description and the input/output size note
gr.HTML("""
<div class="main-header">
<h1>πŸš€ ESRGAN 4x + FLUX Enhancement</h1>
<p>Upload an image to upscale 4x with ESRGAN and enhance with FLUX</p>
<p>Optimized for <strong>ZeroGPU</strong> | Max input: 512x512 β†’ Output: 2048x2048</p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
# Input section
input_image = gr.Image(
label="Input Image",
type="pil",
height=256
)
prompt = gr.Textbox(
label="Enhancement Prompt",
placeholder="Describe the desired enhancement (e.g., 'high quality, sharp details, vibrant colors')",
value="high quality, ultra detailed, sharp",
lines=2
)
# Optional tuning controls, collapsed by default (open=False)
with gr.Accordion("Advanced Settings", open=False):
num_inference_steps = gr.Slider(
label="Enhancement Steps",
minimum=10,
maximum=25,
step=1,
value=18,
info="More steps = better quality but slower"
)
denoising_strength = gr.Slider(
label="Enhancement Strength",
minimum=0.1,
maximum=0.6,
step=0.05,
value=0.35,
info="Higher = more changes to the image"
)
# Seed controls: random-seed toggle (on by default) and a manual seed slider
with gr.Row():
randomize_seed = gr.Checkbox(
label="Randomize seed",
value=True
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42
)
# Button whose click is wired to enhance_image in the event handler below
enhance_btn = gr.Button(
"🎨 Enhance Image (4x Upscale)",
variant="primary",
size="lg"
)
with gr.Column(scale=2):
# Output section
result_slider = ImageSlider(
type="pil",
label="Before / After",
interactive=False,
height=512
)
# Hidden, read-only field that receives the seed returned by enhance_image
used_seed = gr.Number(
label="Seed Used",
interactive=False,
visible=False
)
# Examples
# Each example supplies a value for the prompt textbox only
gr.Examples(
examples=[
["high quality, ultra detailed, sharp"],
["cinematic, professional photography, enhanced details"],
["vibrant colors, high contrast, sharp focus"],
["photorealistic, 8k quality, fine details"],
],
inputs=[prompt],
label="Example Prompts"
)
# Event handler
# The inputs are passed to enhance_image in this order; its two return values
# go to the before/after slider and the seed field
enhance_btn.click(
fn=enhance_image,
inputs=[
input_image,
prompt,
seed,
randomize_seed,
num_inference_steps,
denoising_strength,
],
outputs=[result_slider, used_seed]
)
# Footer describing the processing pipeline
gr.HTML("""
<div style="margin-top: 2rem; text-align: center; color: #666;">
<p>πŸ“Œ Pipeline: ESRGAN 4x-UltraSharp β†’ FLUX Dev Enhancement</p>
<p>⚑ Optimized for ZeroGPU with automatic memory management</p>
</div>
""")
# Start the web server only when this file is executed as a script
if __name__ == "__main__":
    # Limit the request queue to three entries, then listen on every network
    # interface at port 7860 without creating a public share link
    queued_demo = demo.queue(max_size=3)
    queued_demo.launch(server_name="0.0.0.0", server_port=7860, share=False)