import os
import random
import warnings
import gc

import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from diffusers import FluxImg2ImgPipeline
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import requests
# Minimal ESRGAN implementation (avoids the basicsr dependency)
class ResidualDenseBlock(nn.Module):
    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Dense connections: each conv sees the input plus all earlier outputs
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Residual scaling (0.2), as in the ESRGAN paper
        return x5 * 0.2 + x
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three dense blocks plus a scaled skip."""

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + x
class RRDBNet(nn.Module):
    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4):
        super(RRDBNet, self).__init__()
        self.scale = scale
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = nn.Sequential(*[RRDB(num_feat, num_grow_ch) for _ in range(num_block)])
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # Upsampling
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.conv_body(self.body(fea))
        fea = fea + trunk
        # Two fixed nearest-neighbor 2x upsamples -> 4x total
        fea = self.lrelu(self.conv_up1(nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.conv_up2(nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(fea)))
        return out
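
# Note: this minimal RRDBNet hardcodes the 4x path (two 2x upsamples), which is
# what 4x-UltraSharp expects. The `scale` argument is stored but, unlike the full
# basicsr implementation, never builds a 2x/8x variant; other output factors are
# handled later by pre-resizing the input (see esrgan_upscale below).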
css = """ | |
#col-container { | |
margin: 0 auto; | |
max-width: 800px; | |
} | |
.main-header { | |
text-align: center; | |
margin-bottom: 2rem; | |
} | |
""" | |
# Get the Hugging Face token
huggingface_token = os.getenv("HF_TOKEN")

# Download the FLUX model if not already cached
print("📥 Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
# Load the FLUX pipeline on CPU initially (moved to GPU per request)
print("📥 Loading FLUX Img2Img pipeline...")
pipe = FluxImg2ImgPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    use_safetensors=True
)

# Enable memory optimizations: tiled and sliced VAE decoding keep peak VRAM
# down on large outputs
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
# Download and load the ESRGAN 4x-UltraSharp model
print("📥 Loading ESRGAN 4x-UltraSharp...")
esrgan_path = "4x-UltraSharp.pth"
if not os.path.exists(esrgan_path):
    print("Downloading ESRGAN model...")
    url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
    response = requests.get(url, timeout=300)
    response.raise_for_status()  # Fail early on a bad download
    with open(esrgan_path, "wb") as f:
        f.write(response.content)
# Initialize the ESRGAN model
esrgan_model = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=4
)
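
# These hyperparameters (23 RRDB blocks, 64 features, 32 growth channels) are
# the standard ESRGAN 4x architecture that 4x-UltraSharp was trained with; a
# mismatch here would silently leave keys unloaded because of strict=False below.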
# Load the state dict, unwrapping basicsr-style checkpoints if present
state_dict = torch.load(esrgan_path, map_location='cpu')
if 'params_ema' in state_dict:
    state_dict = state_dict['params_ema']
elif 'params' in state_dict:
    state_dict = state_dict['params']

# Strip DataParallel's 'module.' prefix from keys if needed
cleaned_state_dict = {}
for k, v in state_dict.items():
    if k.startswith('module.'):
        cleaned_state_dict[k[7:]] = v
    else:
        cleaned_state_dict[k] = v

# strict=False tolerates naming differences between checkpoint variants
esrgan_model.load_state_dict(cleaned_state_dict, strict=False)
esrgan_model.eval()
print("✅ All models loaded successfully!")
MAX_SEED = 1000000
MAX_INPUT_SIZE = 512  # Max input size before upscaling
def make_multiple_16(n):
    """Round up to the next multiple of 16, as FLUX requires."""
    return ((n + 15) // 16) * 16
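
# For example, make_multiple_16(1000) -> 1008 and make_multiple_16(1024) -> 1024.
# FLUX's 8x VAE downsampling plus 2x2 latent patchification means spatial dims
# must be divisible by 16.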
def truncate_prompt(prompt, max_tokens=75):
    """Truncate the prompt to stay under CLIP's 77-token limit."""
    if not prompt:
        return ""
    # Simple truncation by character count; ~250 characters is roughly 75 tokens
    # (appending "..." would only add tokens back, so the cut is left as-is)
    if len(prompt) > 250:
        return prompt[:250]
    return prompt
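
# The character count is a heuristic: CLIP's BPE averages roughly 3-4 characters
# per English token. A stricter (hypothetical, not wired in) check could count
# tokens directly, e.g. len(pipe.tokenizer(prompt).input_ids) <= 77.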
def prepare_image(image, max_size=MAX_INPUT_SIZE):
    """Prepare the image for processing."""
    # ESRGAN and FLUX both expect 3-channel input (guards against RGBA/grayscale uploads)
    image = image.convert("RGB")
    w, h = image.size
    # Limit input size; thumbnail resizes in place and preserves aspect ratio
    if w > max_size or h > max_size:
        image.thumbnail((max_size, max_size), Image.LANCZOS)
    return image
def esrgan_upscale(image, model, device='cuda', upscale_factor=4):
    """Upscale an image with ESRGAN, supporting factors below the native 4x.

    The network always upscales 4x, so for 2x/3x the input is first downsized
    by upscale_factor/4 and then run through the 4x model.
    """
    orig_w, orig_h = image.size
    pre_resize_factor = upscale_factor / 4.0
    low_res_w = int(orig_w * pre_resize_factor)
    low_res_h = int(orig_h * pre_resize_factor)
    if low_res_w < 1 or low_res_h < 1:
        raise ValueError("Upscale factor too small for image size")
    # BICUBIC better matches the degradations the model was trained on
    low_res_image = image.resize((low_res_w, low_res_h), Image.BICUBIC)

    # HWC uint8 -> normalized CHW float tensor with a batch dim
    img_np = np.array(low_res_image).astype(np.float32) / 255.
    img_np = np.transpose(img_np, (2, 0, 1))
    img_tensor = torch.from_numpy(img_np).unsqueeze(0).to(device)

    # Upscale
    with torch.no_grad():
        output = model(img_tensor)
    output = output.squeeze(0).cpu().clamp(0, 1)
    output_np = output.numpy()
    output_np = np.transpose(output_np, (1, 2, 0))  # CHW to HWC
    output_np = (output_np * 255).astype(np.uint8)
    upscaled = Image.fromarray(output_np)

    # Resize to the exact target size if integer rounding drifted
    target_w = int(orig_w * upscale_factor)
    target_h = int(orig_h * upscale_factor)
    if upscaled.size != (target_w, target_h):
        upscaled = upscaled.resize((target_w, target_h), Image.BICUBIC)
    return upscaled
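
# Standalone usage sketch (assumes a CUDA device; hypothetical file names):
#   img = Image.open("input.png").convert("RGB")
#   out = esrgan_upscale(img, esrgan_model.to("cuda"), upscale_factor=2)
#   out.save("output.png")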
# ZeroGPU: request a GPU for up to 120 seconds per call (increased from the default)
@spaces.GPU(duration=120)
def enhance_image(
    input_image,
    prompt,
    seed,
    randomize_seed,
    num_inference_steps,
    denoising_strength,
    upscale_factor,
    progress=gr.Progress(track_tqdm=True),
):
"""Main enhancement function""" | |
if input_image is None: | |
raise gr.Error("Please upload an image") | |
# Clear memory | |
torch.cuda.empty_cache() | |
gc.collect() | |
try: | |
# Randomize seed if needed | |
if randomize_seed: | |
seed = random.randint(0, MAX_SEED) | |
# Prepare and validate prompt | |
prompt = truncate_prompt(prompt.strip() if prompt else "high quality, detailed") | |
# Prepare input image | |
input_image = prepare_image(input_image) | |
original_size = input_image.size | |
# Step 1: ESRGAN upscale on GPU | |
gr.Info(f"π Upscaling with ESRGAN x{upscale_factor}...") | |
# Move ESRGAN to GPU for faster processing | |
esrgan_model.to("cuda") | |
upscaled_image = esrgan_upscale(input_image, esrgan_model, device="cuda", upscale_factor=upscale_factor) | |
# Move ESRGAN back to CPU to free memory | |
esrgan_model.to("cpu") | |
torch.cuda.empty_cache() | |
# Ensure dimensions are multiples of 16 for FLUX | |
w, h = upscaled_image.size | |
new_w = make_multiple_16(w) | |
new_h = make_multiple_16(h) | |
if new_w != w or new_h != h: | |
# Pad image to meet requirements | |
padded = Image.new('RGB', (new_w, new_h)) | |
padded.paste(upscaled_image, (0, 0)) | |
upscaled_image = padded | |
# Step 2: FLUX enhancement | |
gr.Info("π¨ Enhancing with FLUX...") | |
# Move pipeline to GPU | |
pipe.to("cuda") | |
# Generate with FLUX | |
generator = torch.Generator(device="cuda").manual_seed(seed) | |
with torch.inference_mode(): | |
result = pipe( | |
prompt=prompt, | |
image=upscaled_image, | |
strength=denoising_strength, | |
num_inference_steps=num_inference_steps, | |
guidance_scale=3.5, # Recommended for FLUX.1-dev to reduce artifacts | |
height=new_h, | |
width=new_w, | |
generator=generator, | |
).images[0] | |
# Crop back if we padded | |
if new_w != w or new_h != h: | |
result = result.crop((0, 0, w, h)) | |
# Move pipeline back to CPU | |
pipe.to("cpu") | |
torch.cuda.empty_cache() | |
gc.collect() | |
# Prepare images for slider (before/after) | |
input_resized = input_image.resize(result.size, Image.LANCZOS) | |
gr.Info("β Enhancement complete!") | |
return [input_resized, result], seed | |
except Exception as e: | |
# Cleanup on error | |
pipe.to("cpu") | |
esrgan_model.to("cpu") | |
torch.cuda.empty_cache() | |
gc.collect() | |
raise gr.Error(f"Enhancement failed: {str(e)}") | |
# Create the Gradio interface
with gr.Blocks(css=css) as demo:
    gr.HTML("""
        <div class="main-header">
            <h1>🚀 Flux Dev Ultimate Upscaler</h1>
            <p>Upload an image to upscale 2-4x with ESRGAN and enhance with FLUX</p>
            <p>Optimized for <strong>ZeroGPU</strong> | Max input: 512x512 → Output: up to 2048x2048</p>
        </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=256
            )
            prompt = gr.Textbox(
                label="Describe image with prompt",
                placeholder="Describe the desired enhancement (e.g., 'high quality, sharp details, vibrant colors')",
                value="high quality, ultra detailed, sharp",
                lines=2
            )
            # Advanced settings (always visible)
            upscale_factor = gr.Slider(
                label="Upscale Ratio",
                minimum=2,
                maximum=4,
                step=1,
                value=4,
                info="Choose the upscale factor (2x, 3x, 4x). Use 4x for best results; lower factors may cause color artifacts."
            )
            num_inference_steps = gr.Slider(
                label="Enhancement Steps",
                minimum=10,
                maximum=25,
                step=1,
                value=20,  # Higher default for better denoising
                info="More steps = better quality but slower"
            )
            denoising_strength = gr.Slider(
                label="Creativity (Denoising)",
                minimum=0.1,
                maximum=0.6,
                step=0.05,
                value=0.35,
                info="Higher = more changes to the image"
            )
            with gr.Row():
                randomize_seed = gr.Checkbox(
                    label="Randomize seed",
                    value=True
                )
                seed = gr.Number(
                    label="Seed",
                    value=42
                )
            enhance_btn = gr.Button(
                "Upscale",
                variant="primary",
                size="lg"
            )
        with gr.Column(scale=2):
            # Output section
            result_slider = ImageSlider(
                type="pil",
                label="Before / After",
                interactive=False,
                height=512
            )
            used_seed = gr.Number(
                label="Seed Used",
                interactive=False,
                visible=False
            )
    # Event handler: enhance_image returns ([before, after], seed), matching the two outputs
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            prompt,
            seed,
            randomize_seed,
            num_inference_steps,
            denoising_strength,
            upscale_factor,
        ],
        outputs=[result_slider, used_seed]
    )
gr.HTML(""" | |
<div style="margin-top: 2rem; text-align: center; color: #666;"> | |
<p>π Pipeline: ESRGAN 2-4x-UltraSharp β FLUX Dev Enhancement</p> | |
<p>β‘ Optimized for ZeroGPU with automatic memory management</p> | |
<p>π Note: User is responsible for obtaining commercial license from Flux Dev if using image commercially under their license.</p> | |
</div> | |
""") | |
if __name__ == "__main__":
    demo.queue(max_size=3).launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860
    )