# fluxhdupscaler / app.py
import random
import warnings
import os
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxImg2ImgPipeline
from gradio_imageslider import ImageSlider
from PIL import Image
import requests
from transformers import T5TokenizerFast
# Optional ESRGAN support (requires: pip install basicsr)
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import img2tensor, tensor2img
USE_ESRGAN = True
except ImportError:
USE_ESRGAN = False
warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""
# Device setup - Default to CPU, let runtime handle GPU
power_device = "ZeroGPU"
device = "cpu"
# Get HuggingFace token
huggingface_token = os.getenv("HF_TOKEN")
MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
def make_divisible_by_16(size):
"""Adjust size to nearest multiple of 16, stretching if necessary"""
return ((size // 16) * 16) if (size % 16) < 8 else ((size // 16 + 1) * 16)
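# Quick sanity check of the rounding rule above:
#   make_divisible_by_16(100) -> 96   (100 % 16 == 4 < 8, round down)
#   make_divisible_by_16(105) -> 112  (105 % 16 == 9 >= 8, round up)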
def process_input(input_image, upscale_factor):
"""Process input image and handle size constraints"""
w, h = input_image.size
w_original, h_original = w, h
aspect_ratio = w / h
was_resized = False
if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
warnings.warn(
f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to fit budget."
)
        gr.Info(
            "Requested output image is too large. Resizing input to fit within pixel budget."
        )
target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
scale = (target_input_pixels / (w * h)) ** 0.5
new_w = int(w * scale) // 16 * 16 # Ensure divisible by 16 for Flux compatibility
new_h = int(h * scale) // 16 * 16
if new_w == 0 or new_h == 0:
new_w = max(16, new_w)
new_h = max(16, new_h)
input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
was_resized = True
return input_image, w_original, h_original, was_resized
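# Worked example of the budget check above: a 4000x4000 input at 4x would yield
# 4000 * 4000 * 16 = 256 MP, over the ~67.1 MP MAX_PIXEL_BUDGET. The input is
# scaled by sqrt((8192 * 8192 / 16) / (4000 * 4000)) ~= 0.51 and snapped down
# to a multiple of 16, so FLUX sees roughly a 2048x2048 image before upscaling.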
def load_image_from_url(url):
    """Load an image from a URL."""
    try:
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()
        return Image.open(response.raw).convert("RGB")
    except Exception as e:
        raise gr.Error(f"Failed to load image from URL: {e}")
def esrgan_upscale(image, scale=4):
    """Upscale with ESRGAN when available; otherwise fall back to LANCZOS."""
    if not USE_ESRGAN:
        return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
    img = img2tensor(np.array(image) / 255., bgr2rgb=False, float32=True)
    # Move the input to the same device as the model to avoid a CPU/GPU mismatch
    img = img.unsqueeze(0).to(next(esrgan_model.parameters()).device)
    with torch.no_grad():
        output = esrgan_model(img).squeeze()
    output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
    return Image.fromarray(output_img)
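# Note: the ESRGAN path always upscales by the model's native factor (4x for
# 4x-UltraSharp); `scale` only affects the LANCZOS fallback. That is why
# enhance_image() below only takes this path when upscale_factor == 4.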
def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
"""Tiled Img2Img to mimic Ultimate SD Upscaler tiling"""
w, h = image.size
output = image.copy() # Start with the control image
for x in range(0, w, tile_size - overlap):
for y in range(0, h, tile_size - overlap):
tile_w = min(tile_size, w - x)
tile_h = min(tile_size, h - y)
if tile_h < 16 or tile_w < 16: # Skip tiny tiles
continue
tile = image.crop((x, y, x + tile_w, y + tile_h))
# Force tile to div by 16
new_tile_w = make_divisible_by_16(tile_w)
new_tile_h = make_divisible_by_16(tile_h)
tile = tile.resize((new_tile_w, new_tile_h), resample=Image.LANCZOS)
# Run Flux on tile
gen_tile = pipe(
prompt=prompt,
image=tile,
strength=strength,
num_inference_steps=steps,
guidance_scale=guidance,
height=new_tile_h,
width=new_tile_w,
generator=generator,
).images[0]
# Resize gen_tile back to original tile dimensions
gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)
            # Paste with linear blending across the overlap regions
            if overlap > 0 and (x > 0 or y > 0):
                # Build the blend mask with numpy; np.minimum keeps both ramps
                # where the horizontal and vertical overlaps intersect (the
                # per-pixel loops this replaces overwrote the corner blend).
                mask_arr = np.full((tile_h, tile_w), 255.0, dtype=np.float32)
                if x > 0:
                    ramp = np.arange(min(overlap, tile_w)) / overlap * 255.0
                    mask_arr[:, :len(ramp)] = np.minimum(mask_arr[:, :len(ramp)], ramp[None, :])
                if y > 0:
                    ramp = np.arange(min(overlap, tile_h)) / overlap * 255.0
                    mask_arr[:len(ramp), :] = np.minimum(mask_arr[:len(ramp), :], ramp[:, None])
                mask = Image.fromarray(mask_arr.astype(np.uint8), mode='L')
                output.paste(gen_tile, (x, y, x + tile_w, y + tile_h), mask)
            else:
                output.paste(gen_tile, (x, y))
return output
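# Rough cost model for the loop above: tiles advance by (tile_size - overlap)
# pixels per step, so a 4096x4096 control image with tile_size=1024 and
# overlap=32 needs ceil(4096 / 992) = 5 steps per axis, i.e. up to 25
# pipeline calls. Larger tiles mean fewer calls but more VRAM per call.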
@spaces.GPU(duration=120)
def enhance_image(
image_input,
image_url,
seed,
randomize_seed,
num_inference_steps,
upscale_factor,
denoising_strength,
custom_prompt,
tile_size,
progress=gr.Progress(track_tqdm=True),
):
"""Main enhancement function"""
    # Lazy loading of models
    global pipe, esrgan_model
    # Resolve the device on every call: it is also used after the lazy-load
    # block (generator seeding, cleanup), so leaving it unset on repeat calls
    # would raise UnboundLocalError.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if 'pipe' not in globals():
        try:
            dtype = torch.bfloat16 if device == "cuda" else torch.float32
            print(f"📥 Loading FLUX Img2Img on {device}...")
tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2", token=huggingface_token)
pipe = FluxImg2ImgPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=dtype,
low_cpu_mem_usage=True,
device_map="balanced",
tokenizer_2=tokenizer_2,
token=huggingface_token
)
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
if device == "cuda":
pipe.reset_device_map()
pipe.enable_model_cpu_offload()
            if USE_ESRGAN:
                esrgan_path = "4x-UltraSharp.pth"
                if not os.path.exists(esrgan_path):
                    url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
                    with open(esrgan_path, "wb") as f:
                        f.write(requests.get(url, timeout=60).content)
                esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
                checkpoint = torch.load(esrgan_path, map_location="cpu")
                # Some ESRGAN checkpoints wrap the weights in 'params_ema'; fall back to the raw dict
                state_dict = checkpoint.get('params_ema', checkpoint)
                esrgan_model.load_state_dict(state_dict)
                esrgan_model.eval()
                esrgan_model.to(device)
print("βœ… Models loaded successfully!")
except Exception as e:
print(f"Model loading error: {e}, falling back to CPU")
device = "cpu"
dtype = torch.float32
# Reload on CPU if needed
tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2", token=huggingface_token)
pipe = FluxImg2ImgPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=dtype,
low_cpu_mem_usage=True,
device_map=None,
tokenizer_2=tokenizer_2,
token=huggingface_token
)
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
# Handle image input
if image_input is not None:
input_image = image_input
elif image_url:
input_image = load_image_from_url(image_url)
else:
raise gr.Error("Please provide an image (upload or URL)")
if randomize_seed:
seed = random.randint(0, MAX_SEED)
true_input_image = input_image
# Process input image
input_image, w_original, h_original, was_resized = process_input(
input_image, upscale_factor
)
    prompt = custom_prompt.strip() if custom_prompt else ""
generator = torch.Generator(device=device).manual_seed(seed)
gr.Info("πŸš€ Upscaling image...")
# Initial upscale
if USE_ESRGAN and upscale_factor == 4:
control_image = esrgan_upscale(input_image, upscale_factor)
else:
w, h = input_image.size
control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
# Resize control_image to divisible by 16 (stretching)
control_w, control_h = control_image.size
new_control_w = make_divisible_by_16(control_w)
new_control_h = make_divisible_by_16(control_h)
if (new_control_w, new_control_h) != (control_w, control_h):
control_image = control_image.resize((new_control_w, new_control_h), resample=Image.LANCZOS)
# Tiled Flux Img2Img for refinement
image = tiled_flux_img2img(
pipe,
prompt,
control_image,
denoising_strength,
num_inference_steps,
        3.5,  # guidance_scale (matches the reference workflow)
generator,
tile_size=tile_size,
overlap=32
)
    # Resize back to the original target size if the image was stretched
    # or the input was downscaled to fit the pixel budget
    target_w, target_h = w_original * upscale_factor, h_original * upscale_factor
    if image.size != (target_w, target_h):
        if was_resized:
            gr.Info(f"📏 Resizing output to target size: {target_w}x{target_h}")
        image = image.resize((target_w, target_h), resample=Image.LANCZOS)
# Resize input image to match output size for slider alignment
resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
# Move back to CPU to release GPU if possible
if device == "cuda":
pipe.to("cpu")
if USE_ESRGAN:
esrgan_model.to("cpu")
return [resized_input, image]
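# Minimal local usage sketch (illustrative only; assumes the weights download
# succeeds and enough GPU/CPU memory is available; "photo.jpg" is a placeholder):
#   before, after = enhance_image(
#       Image.open("photo.jpg"), "", seed=42, randomize_seed=False,
#       num_inference_steps=4, upscale_factor=2, denoising_strength=0.3,
#       custom_prompt="", tile_size=1024,
#   )
#   after.save("photo_upscaled.png")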
# Create Gradio interface
with gr.Blocks(css=css, title="🎨 AI Image Upscaler - FLUX") as demo:
gr.HTML("""
<div class="main-header">
<h1>🎨 AI Image Upscaler</h1>
        <p>Upload an image or provide a URL to upscale it with FLUX</p>
<p>Currently running on <strong>{}</strong></p>
</div>
""".format(power_device))
with gr.Row():
with gr.Column(scale=1):
gr.HTML("<h3>πŸ“€ Input</h3>")
with gr.Tabs():
with gr.TabItem("πŸ“ Upload Image"):
input_image = gr.Image(
label="Upload Image",
type="pil",
height=200 # Made smaller
)
with gr.TabItem("πŸ”— Image URL"):
image_url = gr.Textbox(
label="Image URL",
placeholder="https://example.com/image.jpg",
value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
)
gr.HTML("<h3>πŸŽ›οΈ Prompt Settings</h3>")
custom_prompt = gr.Textbox(
label="Custom Prompt (optional)",
placeholder="Enter custom prompt or leave empty",
lines=2
)
gr.HTML("<h3>βš™οΈ Upscaling Settings</h3>")
upscale_factor = gr.Slider(
label="Upscale Factor",
minimum=1,
maximum=4,
step=1,
value=2,
info="How much to upscale the image"
)
num_inference_steps = gr.Slider(
label="Number of Inference Steps",
minimum=1,
maximum=50,
step=1,
value=4,
info="More steps = better quality but slower (default 4 for schnell)"
)
denoising_strength = gr.Slider(
label="Denoising Strength",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.3,
info="Controls how much the image is transformed"
)
tile_size = gr.Slider(
label="Tile Size",
minimum=256,
maximum=2048,
step=64,
value=1024,
info="Size of tiles for processing (larger = faster but more memory)"
)
with gr.Row():
randomize_seed = gr.Checkbox(
label="Randomize seed",
value=True
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
interactive=True
)
enhance_btn = gr.Button(
"πŸš€ Upscale Image",
variant="primary",
size="lg"
)
with gr.Column(scale=2): # Larger scale for results
gr.HTML("<h3>πŸ“Š Results</h3>")
result_slider = ImageSlider(
type="pil",
interactive=False, # Disable interactivity to prevent uploads
height=600, # Made larger
elem_id="result_slider",
label=None # Remove default label
)
# Event handler
enhance_btn.click(
fn=enhance_image,
inputs=[
input_image,
image_url,
seed,
randomize_seed,
num_inference_steps,
upscale_factor,
denoising_strength,
custom_prompt,
tile_size
],
outputs=[result_slider]
)
gr.HTML("""
<div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
        <p><strong>Note:</strong> This upscaler uses the FLUX.1-schnell model. Users are responsible for ensuring that any commercial use complies with the model's license.</p>
</div>
""")
# Custom CSS for slider
gr.HTML("""
<style>
#result_slider .slider {
width: 100% !important;
max-width: inherit !important;
}
#result_slider img {
object-fit: contain !important;
width: 100% !important;
height: auto !important;
}
#result_slider .gr-button-tool {
display: none !important;
}
#result_slider .gr-button-undo {
display: none !important;
}
#result_slider .gr-button-clear {
display: none !important;
}
#result_slider .badge-container .badge {
display: none !important;
}
#result_slider .badge-container::before {
content: "Before";
position: absolute;
top: 10px;
left: 10px;
background: rgba(0,0,0,0.5);
color: white;
padding: 5px;
border-radius: 5px;
z-index: 10;
}
#result_slider .badge-container::after {
content: "After";
position: absolute;
top: 10px;
right: 10px;
background: rgba(0,0,0,0.5);
color: white;
padding: 5px;
border-radius: 5px;
z-index: 10;
}
#result_slider .fullscreen img {
object-fit: contain !important;
width: 100vw !important;
height: 100vh !important;
position: absolute;
top: 0;
left: 0;
}
</style>
""")
# JS to set slider default position to middle
gr.HTML("""
<script>
document.addEventListener('DOMContentLoaded', function() {
const sliderInput = document.querySelector('#result_slider input[type="range"]');
if (sliderInput) {
sliderInput.value = 50;
sliderInput.dispatchEvent(new Event('input'));
}
});
</script>
""")
if __name__ == "__main__":
demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)