Spaces:
Running
on
Zero
Running
on
Zero
File size: 13,815 Bytes
a49f337 93af3e2 b7cfbcf a49f337 c1ad781 597b21e 93af3e2 74168bc 93af3e2 58d1893 74168bc 597b21e 93af3e2 597b21e 93af3e2 a49f337 93af3e2 a49f337 93af3e2 a49f337 93af3e2 b0a0a29 93af3e2 b0a0a29 93af3e2 a49f337 93af3e2 a49f337 74168bc a49f337 74168bc a49f337 b0a0a29 a49f337 b0a0a29 a49f337 b0a9f3e a49f337 b0a9f3e 93af3e2 58d1893 a49f337 93af3e2 a49f337 b0a9f3e a49f337 b0a9f3e a49f337 b0a9f3e 93af3e2 9047ade c2d1871 9047ade 74168bc 9047ade 74168bc a49f337 74168bc 9047ade c2d1871 9047ade 93af3e2 9047ade 93af3e2 a49f337 93af3e2 9047ade 93af3e2 a49f337 b0a9f3e a49f337 b0a9f3e a49f337 b0a9f3e a49f337 b0a9f3e 9047ade 74168bc 9047ade 74168bc b0a9f3e a49f337 b0a9f3e a49f337 b0a9f3e a49f337 b0a9f3e a49f337 b0a9f3e a49f337 b0a0a29 a49f337 3bb8a2e a49f337 b0a9f3e a49f337 b0a9f3e a49f337 b0a9f3e b0a0a29 b0a9f3e a49f337 b0a9f3e a49f337 b0a9f3e a49f337 b0a9f3e b0a0a29 a49f337 93af3e2 a49f337 b0a0a29 a49f337 93af3e2 9047ade 93af3e2 b0a9f3e b0a0a29 93af3e2 a49f337 93af3e2 a1ef78c a49f337 9047ade a49f337 93af3e2 a1ef78c 9047ade c2d1871 9047ade 3bb8a2e 9047ade a49f337 9047ade 93af3e2 a49f337 9047ade a49f337 b0a0a29 b0a9f3e a49f337 93af3e2 a49f337 b0a9f3e a49f337 93af3e2 b0a0a29 a49f337 93af3e2 a49f337 9047ade 93af3e2 a49f337 93af3e2 a1ef78c 93af3e2 a49f337 9047ade a49f337 9047ade 93af3e2 c1ad781 93af3e2 b0a0a29 a49f337 b0a0a29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 |
import os
import random
import warnings
import gc
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from diffusers import FluxImg2ImgPipeline
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import requests
# Minimal ESRGAN implementation (without basicsr dependency)
class ResidualDenseBlock(nn.Module):
    """Five-convolution dense block with a 0.2-scaled residual connection.

    conv1..conv4 each receive the block input concatenated (along dim 1, the
    Conv2d channel axis) with the LeakyReLU-activated outputs of every earlier
    conv, and each emits ``num_grow_ch`` channels. conv5 maps the full
    concatenation back to ``num_feat`` channels (no activation), and the block
    returns ``x + 0.2 * conv5(...)``.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        # conv1..conv4 are registered in order so parameter initialisation and
        # state_dict keys are exactly those of an explicit conv1..conv4 listing.
        for idx in range(4):
            in_channels = num_feat + idx * num_grow_ch
            setattr(self, f"conv{idx + 1}", nn.Conv2d(in_channels, num_grow_ch, 3, 1, 1))
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        residual = self.conv5(torch.cat(features, 1))
        return x + 0.2 * residual
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block.

    Chains three ResidualDenseBlocks (rdb1 -> rdb2 -> rdb3) and adds the
    chain's output, scaled by 0.2, back onto the block input.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        chained = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            chained = dense_block(chained)
        return x + 0.2 * chained
class RRDBNet(nn.Module):
    """ESRGAN-style generator: conv -> RRDB trunk (+ long skip) -> 2x2 upsampling -> output convs.

    The attribute names (conv_first, body, conv_body, conv_up1, conv_up2,
    conv_hr, conv_last) are what checkpoint state-dict keys must match.

    Args:
        num_in_ch: channels of the input image.
        num_out_ch: channels of the output image.
        num_feat: feature width of the trunk.
        num_block: number of RRDB blocks in the trunk.
        num_grow_ch: growth channels inside each dense block.
        scale: overall upscaling factor, one of 4, 2 or 1. The feature path
            always upsamples 4x (two nearest-neighbour 2x stages). For scale 2
            (resp. 1) the input is first folded with ``pixel_unshuffle`` by 2
            (resp. 4), so ``conv_first`` expects ``num_in_ch * 4`` (resp.
            ``* 16``) channels and the input height/width must be divisible by
            2 (resp. 4). Scale 4 feeds the input straight to ``conv_first``.

    Raises:
        ValueError: if ``scale`` is not 1, 2 or 4 (the network cannot honour it).
    """

    # overall scale -> pixel_unshuffle factor applied to the input
    _UNSHUFFLE_FOR_SCALE = {4: 1, 2: 2, 1: 4}

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4):
        super().__init__()
        if scale not in self._UNSHUFFLE_FOR_SCALE:
            raise ValueError(f"RRDBNet supports scale 1, 2 or 4, got {scale!r}")
        self.scale = scale
        self._unshuffle = self._UNSHUFFLE_FOR_SCALE[scale]
        # Module creation order is kept so random initialisation is unchanged for scale=4.
        self.conv_first = nn.Conv2d(num_in_ch * self._unshuffle ** 2, num_feat, 3, 1, 1)
        self.body = nn.Sequential(*[RRDB(num_feat, num_grow_ch) for _ in range(num_block)])
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # Upsampling
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self._unshuffle > 1:
            # Trade spatial resolution for channels so the fixed 4x head yields `scale` overall.
            x = nn.functional.pixel_unshuffle(x, self._unshuffle)
        fea = self.conv_first(x)
        trunk = self.conv_body(self.body(fea))
        fea = fea + trunk
        fea = self.lrelu(self.conv_up1(nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.conv_up2(nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(fea)))
        return out
# Custom stylesheet handed to gr.Blocks(css=...): a centred #col-container capped
# at 800px wide, and centred .main-header text with bottom spacing.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 800px;\n"
    "}\n"
    ".main-header {\n"
    "text-align: center;\n"
    "margin-bottom: 2rem;\n"
    "}\n"
)
# Hugging Face access token from the HF_TOKEN environment variable (None if unset);
# passed to snapshot_download below.
huggingface_token = os.getenv("HF_TOKEN")

# Download FLUX model if not already cached
print("π₯ Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=[
        "*.md",
        "*.gitattributes",
        # from_pretrained() below loads the diffusers per-component subfolders, so the
        # root-level single-file checkpoints are never read; skipping them avoids a
        # very large redundant download.
        # NOTE(review): file names as published in the FLUX.1-dev repo — confirm if its layout changes.
        "flux1-dev.safetensors",
        "ae.safetensors",
    ],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Load FLUX pipeline on CPU initially; enhance_image moves it to CUDA only while it runs.
print("π₯ Loading FLUX Img2Img pipeline...")
pipe = FluxImg2ImgPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)

# Enable VAE memory optimizations once, directly on the VAE.
# NOTE(review): the pipeline-level enable_vae_tiling()/enable_vae_slicing() wrappers
# forward to these same VAE methods (and are deprecated in recent diffusers), so
# calling both sets only enabled each feature twice.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
# Download and load ESRGAN 4x-UltraSharp model
print("π₯ Loading ESRGAN 4x-UltraSharp...")
esrgan_path = "4x-UltraSharp.pth"
if not os.path.exists(esrgan_path):
    print("Downloading ESRGAN model...")
    url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
    # Fail loudly on HTTP errors instead of saving an error page as the checkpoint, and
    # write under a temporary name so an interrupted download never leaves a truncated
    # file that the os.path.exists() check above would accept on the next start.
    response = requests.get(url, timeout=300)
    response.raise_for_status()
    tmp_path = esrgan_path + ".part"
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    os.replace(tmp_path, esrgan_path)


def _convert_legacy_esrgan_keys(sd):
    """Rename legacy ESRGAN checkpoint keys (``model.N.*`` layout) to RRDBNet's attribute names.

    Mapping applied (keys matching none of these are passed through unchanged):
      model.0.*                            -> conv_first.*
      model.1.sub.<i>.RDB<j>.conv<k>.0.*   -> body.<i>.rdb<j>.conv<k>.*
      model.1.sub.<i>.* (no RDB; trunk conv) -> conv_body.*
      model.3.* / model.6.*                -> conv_up1.* / conv_up2.*
      model.8.* / model.10.*               -> conv_hr.* / conv_last.*

    NOTE(review): this is the layout of the original ESRGAN architecture as commonly
    converted for basicsr-style RRDBNet; confirm against the checkpoint in use.
    """
    top_level = {
        "model.0.": "conv_first.",
        "model.3.": "conv_up1.",
        "model.6.": "conv_up2.",
        "model.8.": "conv_hr.",
        "model.10.": "conv_last.",
    }
    trunk_prefix = "model.1.sub."
    converted = {}
    for key, value in sd.items():
        new_key = key
        if key.startswith(trunk_prefix):
            block_idx, _, rest = key[len(trunk_prefix):].partition(".")
            if rest.startswith("RDB"):
                # "RDB1.conv1.0.weight" -> "rdb1.conv1.weight" (drop the Sequential index)
                module_path, _, param = rest.rpartition(".")
                if module_path.endswith(".0"):
                    module_path = module_path[:-2]
                new_key = f"body.{block_idx}.{module_path.replace('RDB', 'rdb', 1)}.{param}"
            else:
                new_key = f"conv_body.{rest}"
        else:
            for old_prefix, new_prefix in top_level.items():
                if key.startswith(old_prefix):
                    new_key = new_prefix + key[len(old_prefix):]
                    break
        converted[new_key] = value
    return converted


# Initialize ESRGAN model
esrgan_model = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=4,
)

# Load state dict. weights_only=True: the file was fetched from the network, so refuse
# to unpickle arbitrary objects — only tensors/containers are accepted.
state_dict = torch.load(esrgan_path, map_location='cpu', weights_only=True)
if 'params_ema' in state_dict:
    state_dict = state_dict['params_ema']
elif 'params' in state_dict:
    state_dict = state_dict['params']

# Strip a DataParallel-style "module." prefix if present.
cleaned_state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
if 'model.0.weight' in cleaned_state_dict:
    cleaned_state_dict = _convert_legacy_esrgan_keys(cleaned_state_dict)

# strict=False is kept so a partially matching checkpoint still loads, but mismatches are
# reported: any layer that was not loaded keeps its random initialisation.
incompatible = esrgan_model.load_state_dict(cleaned_state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    warnings.warn(
        f"ESRGAN checkpoint does not fully match RRDBNet: "
        f"{len(incompatible.missing_keys)} missing keys, "
        f"{len(incompatible.unexpected_keys)} unexpected keys; "
        "unmatched layers keep their random initialisation."
    )
esrgan_model.eval()
print("✅ All models loaded successfully!")
MAX_SEED = 1_000_000  # inclusive upper bound for seeds drawn with random.randint(0, MAX_SEED)
MAX_INPUT_SIZE = 512  # longest allowed input side, in pixels, before ESRGAN upscaling
def make_multiple_16(n, multiple=16):
    """Round ``n`` UP to the next multiple of ``multiple`` (16 by default).

    FLUX needs image height and width divisible by 16. The result is used as a
    padded canvas size, so it must never be smaller than ``n``; that is why this
    is a ceiling rather than a round-to-nearest.

    Args:
        n: non-negative integer size to round.
        multiple: positive step to round up to.

    Returns:
        The smallest multiple of ``multiple`` that is >= ``n``.

    Raises:
        ValueError: if ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")
    return ((n + multiple - 1) // multiple) * multiple
def truncate_prompt(prompt, max_tokens=75):
    """Shorten ``prompt`` to roughly ``max_tokens`` tokens to stay within CLIP's 77-token limit.

    No tokenizer is used. The character budget is approximated as 10 characters
    per 3 tokens (the default 75 tokens gives 250 characters). Prompts longer
    than the budget are cut to it and ``"..."`` is appended.

    Args:
        prompt: text to shorten; ``None`` or ``""`` yields ``""``.
        max_tokens: approximate token budget.

    Returns:
        The prompt unchanged if within budget, otherwise the truncated prompt
        followed by ``"..."``.
    """
    if not prompt:
        return ""
    # Rough character budget derived from the token budget (75 tokens -> 250 chars).
    max_chars = max(0, max_tokens * 10 // 3)
    if len(prompt) > max_chars:
        return prompt[:max_chars] + "..."
    return prompt
def prepare_image(image, max_size=MAX_INPUT_SIZE):
    """Return an RGB copy of ``image`` whose sides are at most ``max_size`` pixels.

    Larger images are shrunk with ``Image.thumbnail`` (aspect ratio kept, LANCZOS);
    smaller images are left at their size. The caller's image is never modified.
    Converting to RGB guarantees the 3 channels the ESRGAN model is built for;
    grayscale/RGBA/palette images would otherwise break the HWC->CHW conversion.

    Args:
        image: PIL image in any mode.
        max_size: maximum width/height in pixels.

    Returns:
        A new RGB PIL image.
    """
    # convert() always returns a new image, so the in-place thumbnail() below cannot
    # touch the caller's object.
    image = image.convert("RGB")
    w, h = image.size
    # Limit input size
    if w > max_size or h > max_size:
        image.thumbnail((max_size, max_size), Image.LANCZOS)
    return image
def esrgan_upscale(image, model, device='cuda', upscale_factor=4):
    """Upscale a PIL image by ``upscale_factor`` using a fixed-4x ESRGAN model.

    The model always enlarges 4x. For other factors, the input is first bicubic-
    resized by ``upscale_factor / 4``, so the 4x output lands near the requested
    size. A final bicubic resize corrects any rounding, so the result is exactly
    ``int(w * upscale_factor)`` x ``int(h * upscale_factor)``.

    Args:
        image: PIL image; converted to RGB if it is not already.
        model: network mapping a (1, 3, H, W) tensor in [0, 1] to (1, 3, 4H, 4W);
            it must already be on ``device``.
        device: torch device the input tensor is moved to.
        upscale_factor: overall enlargement factor (e.g. 2, 3, 4).

    Returns:
        The upscaled RGB PIL image.

    Raises:
        ValueError: if the pre-resize would produce a zero-sized image.
    """
    image = image.convert("RGB")
    orig_w, orig_h = image.size
    pre_resize_factor = upscale_factor / 4.0
    low_res_w = int(orig_w * pre_resize_factor)
    low_res_h = int(orig_h * pre_resize_factor)
    if low_res_w < 1 or low_res_h < 1:
        raise ValueError("Upscale factor too small for image size")
    # BICUBIC to better match the degradation the model was trained on
    low_res_image = image.resize((low_res_w, low_res_h), Image.BICUBIC)

    # HWC uint8 -> 1xCxHxW in [0, 1], in the model's parameter dtype.
    model_dtype = next((p.dtype for p in model.parameters()), torch.float32)
    img_np = np.asarray(low_res_image, dtype=np.float32) / 255.0
    img_np = np.transpose(img_np, (2, 0, 1))  # HWC to CHW
    img_tensor = torch.from_numpy(img_np).unsqueeze(0).to(device=device, dtype=model_dtype)

    with torch.no_grad():
        output = model(img_tensor)

    # .float() before .numpy(): NumPy cannot represent bfloat16 tensors.
    output = output.squeeze(0).float().cpu().clamp(0, 1)
    output_np = np.transpose(output.numpy(), (1, 2, 0))  # CHW to HWC
    # Round (not truncate) when quantising to 8 bits; truncation would shift every
    # pixel down by up to one level.
    output_np = (output_np * 255.0).round().astype(np.uint8)
    upscaled = Image.fromarray(output_np)

    # Resize to the exact target size if the 4x output is off due to rounding.
    target_w = int(orig_w * upscale_factor)
    target_h = int(orig_h * upscale_factor)
    if upscaled.size != (target_w, target_h):
        upscaled = upscaled.resize((target_w, target_h), Image.BICUBIC)
    return upscaled
def _offload_and_free_gpu():
    """Move the FLUX pipeline and the ESRGAN model to the CPU and release cached CUDA memory."""
    pipe.to("cpu")
    esrgan_model.to("cpu")
    # Collect Python garbage first so the tensors it frees can be returned by empty_cache().
    gc.collect()
    torch.cuda.empty_cache()


@spaces.GPU(duration=120)  # Increased to 120 seconds
def enhance_image(
    input_image,
    prompt,
    seed,
    randomize_seed,
    num_inference_steps,
    denoising_strength,
    upscale_factor,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale ``input_image`` with ESRGAN, then refine it with FLUX img2img.

    Steps:
    1. Cap the input at MAX_INPUT_SIZE (prepare_image).
    2. ESRGAN-upscale by ``upscale_factor`` on CUDA.
    3. Pad to multiples of 16 for FLUX.
    4. Run the FLUX pipeline.
    5. Crop the padding back off.

    Returns:
        ``([input resized to the output size, enhanced image], seed_used)``.

    Raises:
        gr.Error: if no image is given, or if any processing step fails. Whether
            processing succeeds or fails, both models are moved back to the CPU and
            the CUDA cache is cleared afterwards.
    """
    if input_image is None:
        raise gr.Error("Please upload an image")

    # Clear memory
    gc.collect()
    torch.cuda.empty_cache()

    try:
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # gr.Number delivers a float unless precision=0, and torch.Generator.manual_seed
        # rejects floats; slider values are normalised the same way.
        seed = int(seed)
        num_inference_steps = int(num_inference_steps)

        # Fall back to the default prompt for None, "" and whitespace-only input.
        prompt = truncate_prompt((prompt or "").strip() or "high quality, detailed")

        input_image = prepare_image(input_image)

        # Step 1: ESRGAN upscale on GPU
        gr.Info(f"π Upscaling with ESRGAN x{upscale_factor}...")
        esrgan_model.to("cuda")
        upscaled_image = esrgan_upscale(input_image, esrgan_model, device="cuda", upscale_factor=upscale_factor)
        # Free the GPU for FLUX before loading it.
        esrgan_model.to("cpu")
        torch.cuda.empty_cache()

        # Pad (black, bottom/right) so both dimensions are multiples of 16 for FLUX.
        w, h = upscaled_image.size
        new_w = make_multiple_16(w)
        new_h = make_multiple_16(h)
        padded = (new_w, new_h) != (w, h)
        if padded:
            canvas = Image.new('RGB', (new_w, new_h))
            canvas.paste(upscaled_image, (0, 0))
            upscaled_image = canvas

        # Step 2: FLUX enhancement
        gr.Info("🎨 Enhancing with FLUX...")
        pipe.to("cuda")
        generator = torch.Generator(device="cuda").manual_seed(seed)
        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                image=upscaled_image,
                strength=denoising_strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=3.5,  # Recommended for FLUX.1-dev to reduce artifacts
                height=new_h,
                width=new_w,
                generator=generator,
            ).images[0]

        if padded:
            result = result.crop((0, 0, w, h))

        # Before/after pair for the slider, both at the output resolution.
        input_resized = input_image.resize(result.size, Image.LANCZOS)
        gr.Info("✅ Enhancement complete!")
        return [input_resized, result], seed
    except Exception as e:
        raise gr.Error(f"Enhancement failed: {str(e)}") from e
    finally:
        # Runs on success and failure alike, so models never stay parked on the GPU.
        _offload_and_free_gpu()
# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.HTML("""
<div class="main-header">
<h1>π Flux Dev Ultimate Upscaler</h1>
<p>Upload an image to upscale 2-4x with ESRGAN and enhance with FLUX</p>
<p>Optimized for <strong>ZeroGPU</strong> | Max input: 512x512 → Output: up to 2048x2048</p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=256,
            )
            prompt = gr.Textbox(
                label="Describe image with prompt",
                placeholder="Describe the desired enhancement (e.g., 'high quality, sharp details, vibrant colors')",
                value="high quality, ultra detailed, sharp",
                lines=2,
            )
            # Advanced Settings (always open)
            upscale_factor = gr.Slider(
                label="Upscale Ratio",
                minimum=2,
                maximum=4,
                step=1,
                value=4,
                info="Choose upscale factor (2x, 3x, 4x). Use 4x for best results; lower may cause color artifacts.",
            )
            num_inference_steps = gr.Slider(
                label="Enhancement Steps",
                minimum=10,
                maximum=25,
                step=1,
                value=20,  # Increased default for better denoising
                info="More steps = better quality but slower",
            )
            denoising_strength = gr.Slider(
                label="Creativity (Denoising)",
                minimum=0.1,
                maximum=0.6,
                step=0.05,
                value=0.35,
                info="Higher = more changes to the image",
            )
            with gr.Row():
                randomize_seed = gr.Checkbox(
                    label="Randomize seed",
                    value=True,
                )
                # precision=0 makes Gradio deliver an int; the seed ends up in
                # torch.Generator.manual_seed, which does not accept floats.
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                )
            enhance_btn = gr.Button(
                "Upscale",
                variant="primary",
                size="lg",
            )
        with gr.Column(scale=2):
            # Output section
            result_slider = ImageSlider(
                type="pil",
                label="Before / After",
                interactive=False,
                height=512,
            )
            used_seed = gr.Number(
                label="Seed Used",
                interactive=False,
                visible=False,
                precision=0,
            )
    # Event handler
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            prompt,
            seed,
            randomize_seed,
            num_inference_steps,
            denoising_strength,
            upscale_factor,
        ],
        outputs=[result_slider, used_seed],
    )
    gr.HTML("""
<div style="margin-top: 2rem; text-align: center; color: #666;">
<p>π Pipeline: ESRGAN 2-4x-UltraSharp → FLUX Dev Enhancement</p>
<p>β‘ Optimized for ZeroGPU with automatic memory management</p>
<p>π Note: User is responsible for obtaining commercial license from Flux Dev if using image commercially under their license.</p>
</div>
""")
if __name__ == "__main__":
    # Keep at most 3 pending requests queued, and listen on every interface at port 7860.
    queued_app = demo.queue(max_size=3)
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )