# fluxhdupscaler / app.py
# Author: comrender - "Update app.py" (commit 3bb8a2e, verified)
# Source page links: raw | history | blame - file size 13.8 kB
import os
import random
import warnings
import gc
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from diffusers import FluxImg2ImgPipeline
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import requests
# Minimal ESRGAN implementation (without basicsr dependency)
class ResidualDenseBlock(nn.Module):
    """Five-convolution densely connected block used inside an RRDB.

    Each of the first four convolutions receives the channel-wise
    concatenation of the block input and every earlier activation and emits
    ``num_grow_ch`` new channels. The fifth convolution projects back to
    ``num_feat`` channels, and that projection is added to the input with a
    residual scale of 0.2.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        # The attribute names conv1..conv5 become the state_dict keys that
        # pretrained weights are loaded into, so they must not change.
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, kernel_size=3, stride=1, padding=1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Grow the list of feature maps; every conv sees all of them so far.
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        fused = self.conv5(torch.cat(features, 1))
        return fused * 0.2 + x
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block.

    Three ``ResidualDenseBlock`` modules are applied in sequence. The result
    is scaled by 0.2 and added back to the block input.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        # rdb1..rdb3 are state_dict key names; keep them stable.
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            out = dense_block(out)
        return out * 0.2 + x
class RRDBNet(nn.Module):
    """ESRGAN generator network.

    The stages are:
    1. A shallow convolution.
    2. A trunk of ``num_block`` RRDBs closed by a convolution, plus a long
       skip connection around the trunk.
    3. Two nearest-neighbour 2x upsampling stages.
    4. A two-convolution reconstruction head.

    ``scale`` is stored on the instance, but ``forward`` always applies
    exactly two 2x upsampling stages. The output is therefore 4x the input
    size whatever ``scale`` is set to.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4):
        super().__init__()
        self.scale = scale
        # Attribute names below are the state_dict keys for pretrained weights.
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        trunk_blocks = [RRDB(num_feat, num_grow_ch) for _ in range(num_block)]
        self.body = nn.Sequential(*trunk_blocks)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # Upsampling
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = self.conv_first(x)
        # Long skip connection around the RRDB trunk.
        feat = feat + self.conv_body(self.body(feat))
        for up_conv in (self.conv_up1, self.conv_up2):
            enlarged = nn.functional.interpolate(feat, scale_factor=2, mode='nearest')
            feat = self.lrelu(up_conv(enlarged))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))
# Custom CSS passed to gr.Blocks(css=css) when the interface is built.
# `.main-header` styles the title banner <div class="main-header">.
# NOTE(review): no component visible in this file sets elem_id="col-container",
# so the `#col-container` rule appears to be unused - confirm before relying on it.
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""
# Get HuggingFace token (FLUX.1-dev is a gated repo; None falls back to any cached login)
huggingface_token = os.getenv("HF_TOKEN")
# Download FLUX model if not already cached
print("πŸ“₯ Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    # Besides docs, skip the root-level single-file checkpoints and preview
    # image. FluxImg2ImgPipeline.from_pretrained below loads the diffusers
    # layout (model_index.json + component subfolders) and never reads them.
    # NOTE(review): these file names follow the FLUX.1-dev repo layout;
    # confirm against the current repo listing. Patterns that match nothing
    # are harmless.
    ignore_patterns=[
        "*.md",
        "*.gitattributes",
        "flux1-dev.safetensors",
        "ae.safetensors",
        "dev_grid.jpg",
    ],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
# Load FLUX pipeline on CPU initially
print("πŸ“₯ Loading FLUX Img2Img pipeline...")
pipe = FluxImg2ImgPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)
# Enable memory optimizations on the VAE itself. In diffusers, the
# pipeline-level enable_vae_tiling()/enable_vae_slicing() helpers (where a
# pipeline defines them) only forward to these same VAE methods and are being
# deprecated. Calling the VAE methods once is sufficient.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
# Download and load ESRGAN 4x-UltraSharp model
print("πŸ“₯ Loading ESRGAN 4x-UltraSharp...")
esrgan_path = "4x-UltraSharp.pth"
if not os.path.exists(esrgan_path):
    print("Downloading ESRGAN model...")
    url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
    # Stream into a temporary file and rename only after a successful
    # transfer. Otherwise an HTTP error page or an interrupted download would
    # be saved as the .pth, and the exists() check above would never retry.
    partial_path = esrgan_path + ".part"
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(partial_path, esrgan_path)


def _convert_legacy_esrgan_keys(state_dict):
    """Rename legacy flat-``Sequential`` ESRGAN keys to ``RRDBNet`` attribute names.

    Detection: the state dict is treated as legacy when it contains
    ``model.0.weight``. Any other state dict is returned unchanged.

    Assumed legacy 4x layout (NOTE(review): taken from the original ESRGAN
    ``RRDBNet_arch``; verify against the checkpoint if mismatches are still
    reported):

    - ``model.0`` -> ``conv_first``
    - ``model.1.sub.<i>.RDB<j>.conv<k>.0`` -> ``body.<i>.rdb<j>.conv<k>``
    - ``model.1.sub.<num_block>`` -> ``conv_body``
    - ``model.3`` / ``model.6`` -> ``conv_up1`` / ``conv_up2``
    - ``model.8`` / ``model.10`` -> ``conv_hr`` / ``conv_last``

    Keys that match none of these patterns are passed through unchanged, so
    ``load_state_dict`` reports them as unexpected.
    """
    if "model.0.weight" not in state_dict:
        return state_dict
    top_level = {
        "model.0": "conv_first",
        "model.3": "conv_up1",
        "model.6": "conv_up2",
        "model.8": "conv_hr",
        "model.10": "conv_last",
    }
    converted = {}
    for key, value in state_dict.items():
        prefix, _, param = key.rpartition(".")
        parts = prefix.split(".")
        in_trunk = parts[:3] == ["model", "1", "sub"]
        if prefix in top_level:
            new_key = f"{top_level[prefix]}.{param}"
        elif in_trunk and len(parts) == 4:
            # model.1.sub.<num_block>: the convolution closing the trunk.
            new_key = f"conv_body.{param}"
        elif in_trunk and len(parts) == 7:
            # model.1.sub.<i>.RDB<j>.conv<k>.0 -> body.<i>.rdb<j>.conv<k>
            new_key = f"body.{parts[3]}.{parts[4].lower()}.{parts[5]}.{param}"
        else:
            new_key = key
        converted[new_key] = value
    return converted


# Initialize ESRGAN model
esrgan_model = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=4,
)
# Load state dict
# NOTE(review): torch.load unpickles the downloaded file, and unpickling can
# execute arbitrary code. Only load checkpoints from a trusted source, and
# consider weights_only=True if the installed torch version supports it.
state_dict = torch.load(esrgan_path, map_location='cpu')
# Some checkpoints nest the weights under 'params_ema' (preferred) or 'params'.
if 'params_ema' in state_dict:
    state_dict = state_dict['params_ema']
elif 'params' in state_dict:
    state_dict = state_dict['params']
# Strip the 'module.' prefix left by (Distributed)DataParallel wrappers.
cleaned_state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
cleaned_state_dict = _convert_legacy_esrgan_keys(cleaned_state_dict)
load_result = esrgan_model.load_state_dict(cleaned_state_dict, strict=False)
# strict=False silently skips mismatched keys and leaves those layers randomly
# initialised, so make any architecture/checkpoint mismatch visible in the logs.
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: ESRGAN weights do not fully match the architecture: "
        f"{len(load_result.missing_keys)} missing keys, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )
esrgan_model.eval()
print("βœ… All models loaded successfully!")
# Inclusive upper bound for randomly drawn seeds (random.randint(0, MAX_SEED)).
MAX_SEED = 1_000_000
# Maximum width/height, in pixels, of the input image before upscaling.
MAX_INPUT_SIZE = 512
def make_multiple_16(n):
    """Round ``n`` UP to the next multiple of 16, as FLUX requires.

    Values that are already multiples of 16 are returned unchanged, e.g.
    ``make_multiple_16(16) == 16`` and ``make_multiple_16(17) == 32``. Because
    this never rounds down, the result can always hold an image of size ``n``
    when it is used as a padded dimension.
    """
    return ((n + 15) // 16) * 16
def truncate_prompt(prompt, max_tokens=75):
    """Truncate a prompt to stay roughly within the CLIP token limit (77 tokens).

    Token counts are approximated from characters at about 10 characters per
    3 tokens. The default budget of 75 tokens therefore allows 250
    characters. Longer prompts are cut to that length and suffixed with
    ``"..."``.

    Args:
        prompt: Prompt text. ``None`` or ``""`` returns ``""``.
        max_tokens: Approximate token budget used to derive the character limit.

    Returns:
        The prompt unchanged if it fits, otherwise the truncated prompt
        followed by ``"..."``.
    """
    if not prompt:
        return ""
    # Rough character budget: 75 tokens -> 250 characters.
    max_chars = int(max_tokens * 10 // 3)
    if len(prompt) > max_chars:
        return prompt[:max_chars] + "..."
    return prompt
def prepare_image(image, max_size=MAX_INPUT_SIZE):
    """Return an RGB copy of ``image`` that fits within ``max_size`` x ``max_size``.

    Oversized images are shrunk with ``Image.thumbnail``, which keeps the
    aspect ratio and never enlarges, using LANCZOS resampling. The image
    object passed in is not modified.

    Args:
        image: PIL image of any mode (RGBA, L, P, ... are converted to RGB).
        max_size: Maximum allowed width and height in pixels.

    Returns:
        A new RGB PIL image.
    """
    # convert() always returns a new image, so the in-place thumbnail() below
    # no longer mutates the caller's object. RGB also guarantees the 3
    # channels that the ESRGAN input and the padding code expect.
    prepared = image.convert("RGB")
    if prepared.width > max_size or prepared.height > max_size:
        prepared.thumbnail((max_size, max_size), Image.LANCZOS)
    return prepared
def esrgan_upscale(image, model, device='cuda', upscale_factor=4):
    """Upscale a PIL image by ``upscale_factor`` using a fixed-4x ESRGAN model.

    The network always upsamples by 4, so other factors are emulated in three
    steps:
    1. Bicubic-resize the input by ``upscale_factor / 4``.
    2. Run the network on the resized image.
    3. If integer rounding makes the output miss the target, bicubic-resize it
       to exactly ``(int(w * upscale_factor), int(h * upscale_factor))``.

    Args:
        image: PIL image. Non-RGB modes are converted to RGB first.
        model: Upscaling module. NOTE(review): assumed to map a
            (1, 3, H, W) float tensor in [0, 1] to (1, 3, 4H, 4W) and to
            already live on ``device``.
        device: Device the input tensor is moved to.
        upscale_factor: Overall scale factor (the UI offers 2-4).

    Returns:
        The upscaled RGB PIL image.

    Raises:
        ValueError: If the pre-resized image would be smaller than 1 pixel.
    """
    if image.mode != "RGB":
        # The network has 3 input channels; RGBA/L/P arrays would not fit
        # the HWC -> CHW transpose below.
        image = image.convert("RGB")
    orig_w, orig_h = image.size
    pre_resize_factor = upscale_factor / 4.0
    low_res_w = int(orig_w * pre_resize_factor)
    low_res_h = int(orig_h * pre_resize_factor)
    if low_res_w < 1 or low_res_h < 1:
        raise ValueError("Upscale factor too small for image size")
    # BICUBIC is chosen to better match the degradation used in training.
    low_res_image = image.resize((low_res_w, low_res_h), Image.BICUBIC)
    # Prepare image: uint8 HWC -> float32 CHW in [0, 1], with a batch dim.
    img_np = np.array(low_res_image).astype(np.float32) / 255.
    img_np = np.transpose(img_np, (2, 0, 1))
    img_tensor = torch.from_numpy(img_np).unsqueeze(0).to(device)
    # Upscale
    with torch.no_grad():
        output = model(img_tensor)
    output = output.squeeze(0).cpu().clamp(0, 1)
    output_np = np.transpose(output.numpy(), (1, 2, 0))  # CHW to HWC
    # Round (rather than truncate) when quantising to 8 bits; a bare astype
    # floors every value and biases the result darker by ~0.5 LSB.
    output_np = (output_np * 255.0).round().astype(np.uint8)
    upscaled = Image.fromarray(output_np)
    # Resize to the exact target size if rounding made the output miss it.
    target_w = int(orig_w * upscale_factor)
    target_h = int(orig_h * upscale_factor)
    if upscaled.size != (target_w, target_h):
        upscaled = upscaled.resize((target_w, target_h), Image.BICUBIC)
    return upscaled
def _pad_to_multiple_16(image):
    """Edge-pad an RGB image on the right/bottom so both sides are multiples of 16.

    Returns the image unchanged when no padding is needed.
    """
    w, h = image.size
    new_w, new_h = make_multiple_16(w), make_multiple_16(h)
    if (new_w, new_h) == (w, h):
        return image
    # Replicate the border pixels instead of padding with black, so the padded
    # strip does not bleed dark values into the real content while FLUX
    # denoises. The strip is cropped away afterwards.
    pixels = np.asarray(image.convert("RGB"))
    pixels = np.pad(pixels, ((0, new_h - h), (0, new_w - w), (0, 0)), mode="edge")
    return Image.fromarray(pixels)


@spaces.GPU(duration=120)  # Increased to 120 seconds
def enhance_image(
    input_image,
    prompt,
    seed,
    randomize_seed,
    num_inference_steps,
    denoising_strength,
    upscale_factor,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale an image with ESRGAN, then refine it with FLUX img2img.

    Args:
        input_image: PIL image from the Gradio input, or ``None`` if nothing
            was uploaded.
        prompt: Prompt text. If it is empty or whitespace-only, the default
            "high quality, detailed" is used.
        seed: Generator seed. It is ignored when ``randomize_seed`` is true.
        randomize_seed: If true, draw a new seed in ``[0, MAX_SEED]``.
        num_inference_steps: FLUX denoising steps.
        denoising_strength: img2img ``strength`` passed to the pipeline.
        upscale_factor: ESRGAN upscale ratio.
        progress: Gradio progress tracker, declared with ``track_tqdm=True``.

    Returns:
        A tuple ``([before, after], seed)``:
        - ``before``: the size-limited input resized to the result's size,
          for the before/after slider.
        - ``after``: the enhanced image.
        - ``seed``: the seed actually used.

    Raises:
        gr.Error: If no image is given or any processing step fails.
    """
    if input_image is None:
        raise gr.Error("Please upload an image")
    # Clear memory
    torch.cuda.empty_cache()
    gc.collect()
    try:
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # Gradio number/slider values may arrive as floats.
        # torch.Generator.manual_seed only accepts ints, and the step count
        # is used as an integer count by the scheduler.
        seed = int(seed)
        num_inference_steps = int(num_inference_steps)
        # Fall back to the default prompt when the box is empty OR whitespace-only.
        prompt = truncate_prompt((prompt or "").strip() or "high quality, detailed")
        input_image = prepare_image(input_image)

        # Step 1: ESRGAN upscale on GPU, then move ESRGAN off the GPU to make
        # room for FLUX.
        gr.Info(f"πŸ” Upscaling with ESRGAN x{upscale_factor}...")
        esrgan_model.to("cuda")
        upscaled_image = esrgan_upscale(input_image, esrgan_model, device="cuda", upscale_factor=upscale_factor)
        esrgan_model.to("cpu")
        torch.cuda.empty_cache()

        # FLUX requires width/height that are multiples of 16.
        w, h = upscaled_image.size
        flux_input = _pad_to_multiple_16(upscaled_image)
        new_w, new_h = flux_input.size

        # Step 2: FLUX enhancement
        gr.Info("🎨 Enhancing with FLUX...")
        pipe.to("cuda")
        generator = torch.Generator(device="cuda").manual_seed(seed)
        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                image=flux_input,
                strength=denoising_strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=3.5,  # Recommended for FLUX.1-dev to reduce artifacts
                height=new_h,
                width=new_w,
                generator=generator,
            ).images[0]
        # Crop back if we padded
        if (new_w, new_h) != (w, h):
            result = result.crop((0, 0, w, h))

        # Prepare images for slider (before/after)
        input_resized = input_image.resize(result.size, Image.LANCZOS)
        gr.Info("βœ… Enhancement complete!")
        return [input_resized, result], seed
    except Exception as e:
        raise gr.Error(f"Enhancement failed: {str(e)}") from e
    finally:
        # Runs on success and on failure: release GPU memory in both cases.
        pipe.to("cpu")
        esrgan_model.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.HTML("""
<div class="main-header">
<h1>πŸš€ Flux Dev Ultimate Upscaler</h1>
<p>Upload an image to upscale 2-4x with ESRGAN and enhance with FLUX</p>
<p>Optimized for <strong>ZeroGPU</strong> | Max input: 512x512 β†’ Output: up to 2048x2048</p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=256,
            )
            prompt = gr.Textbox(
                label="Describe image with prompt",
                placeholder="Describe the desired enhancement (e.g., 'high quality, sharp details, vibrant colors')",
                value="high quality, ultra detailed, sharp",
                lines=2,
            )
            # Advanced Settings (always open)
            upscale_factor = gr.Slider(
                label="Upscale Ratio",
                minimum=2,
                maximum=4,
                step=1,
                value=4,
                info="Choose upscale factor (2x, 3x, 4x). Use 4x for best results; lower may cause color artifacts.",
            )
            num_inference_steps = gr.Slider(
                label="Enhancement Steps",
                minimum=10,
                maximum=25,
                step=1,
                value=20,  # Increased default for better denoising
                info="More steps = better quality but slower",
            )
            denoising_strength = gr.Slider(
                label="Creativity (Denoising)",
                minimum=0.1,
                maximum=0.6,
                step=0.05,
                value=0.35,
                info="Higher = more changes to the image",
            )
            with gr.Row():
                randomize_seed = gr.Checkbox(
                    label="Randomize seed",
                    value=True,
                )
                # precision=0 makes Gradio deliver an int. With the default
                # precision=None the value may arrive as a float, which
                # torch.Generator.manual_seed does not accept.
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                )
            enhance_btn = gr.Button(
                "Upscale",
                variant="primary",
                size="lg",
            )
        with gr.Column(scale=2):
            # Output section
            result_slider = ImageSlider(
                type="pil",
                label="Before / After",
                interactive=False,
                height=512,
            )
            used_seed = gr.Number(
                label="Seed Used",
                interactive=False,
                visible=False,
            )
    # Event handler
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            prompt,
            seed,
            randomize_seed,
            num_inference_steps,
            denoising_strength,
            upscale_factor,
        ],
        outputs=[result_slider, used_seed],
    )
    gr.HTML("""
<div style="margin-top: 2rem; text-align: center; color: #666;">
<p>πŸ“Œ Pipeline: ESRGAN 2-4x-UltraSharp β†’ FLUX Dev Enhancement</p>
<p>⚑ Optimized for ZeroGPU with automatic memory management</p>
<p>πŸ“Œ Note: User is responsible for obtaining commercial license from Flux Dev if using image commercially under their license.</p>
</div>
""")
if __name__ == "__main__":
    # Enable the request queue (holding at most 3 events), then serve on all
    # interfaces at port 7860 without a public share link.
    queued_app = demo.queue(max_size=3)
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )