import logging
import random
import warnings
import os
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxImg2ImgPipeline
from transformers import AutoProcessor, AutoModelForCausalLM
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import requests
import io
import base64
# Optional ESRGAN pre-upscale (requires: pip install basicsr)
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import img2tensor, tensor2img
USE_ESRGAN = True
except ImportError:
USE_ESRGAN = False
warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""
# Device setup
if torch.cuda.is_available():
power_device = "GPU"
device = "cuda"
else:
power_device = "CPU"
device = "cpu"
# Get HuggingFace token
huggingface_token = os.getenv("HF_TOKEN")
# Download FLUX model
print("πŸ“₯ Downloading FLUX model...")
model_path = snapshot_download(
repo_id="black-forest-labs/FLUX.1-dev",
repo_type="model",
ignore_patterns=["*.md", "*.gitattributes"],
local_dir="FLUX.1-dev",
token=huggingface_token,
)
# Load Florence-2 model for image captioning
print("πŸ“₯ Loading Florence-2 model...")
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    # fp16 halves memory on GPU; fall back to fp32 on CPU, where half
    # precision is slow and not fully supported.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True,
    attn_implementation="eager"
).to(device)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-large",
trust_remote_code=True
)
# Load FLUX Img2Img pipeline
print("πŸ“₯ Loading FLUX Img2Img...")
pipe = FluxImg2ImgPipeline.from_pretrained(
model_path,
torch_dtype=torch.bfloat16
)
pipe.to(device)
# Reduce VAE memory use: call enable_tiling()/enable_slicing() on the
# AutoencoderKL itself, which is portable across diffusers versions (the
# pipeline-level enable_vae_* wrappers are not defined for every pipeline).
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
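# VAE tiling decodes the latent in overlapping patches rather than one pass,
# so peak VRAM scales with the tile size instead of the full output
# resolution; slicing does the same across the batch dimension.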
print("βœ… All models loaded successfully!")
# Download ESRGAN model if using
if USE_ESRGAN:
esrgan_path = "4x-UltraSharp.pth"
if not os.path.exists(esrgan_path):
url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
with open(esrgan_path, "wb") as f:
            f.write(requests.get(url, timeout=120).content)
esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    # Checkpoint layouts vary: Real-ESRGAN releases wrap the weights under
    # 'params_ema'/'params', while plain ESRGAN checkpoints are a bare state dict.
    raw_state = torch.load(esrgan_path, map_location="cpu")
    state_dict = raw_state.get("params_ema", raw_state.get("params", raw_state))
esrgan_model.load_state_dict(state_dict)
esrgan_model.eval()
esrgan_model.to(device)
MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 8192 * 8192
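# Budget intuition: a 2048x2048 input at 4x needs 8192 * 8192 = 67,108,864
# output pixels, which exactly fills MAX_PIXEL_BUDGET; anything larger is
# downscaled first by process_input() below.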
def generate_caption(image):
"""Generate detailed caption using Florence-2"""
try:
task_prompt = "<MORE_DETAILED_CAPTION>"
prompt = task_prompt
inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
generated_ids = florence_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3,
do_sample=True,
)
generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = florence_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
caption = parsed_answer[task_prompt]
return caption
except Exception as e:
print(f"Caption generation failed: {e}")
return "a high quality detailed image"
def process_input(input_image, upscale_factor):
"""Process input image and handle size constraints"""
w, h = input_image.size
w_original, h_original = w, h
was_resized = False
if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
warnings.warn(
f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to fit budget."
)
        gr.Info(
            "Requested output image is too large. Resizing input to fit within pixel budget."
        )
        target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
        scale = (target_input_pixels / (w * h)) ** 0.5
        # Snap to multiples of 16: Flux pipelines expect height/width divisible
        # by 2 * vae_scale_factor (16) and will round them otherwise.
        new_w = int(w * scale) - int(w * scale) % 16
        new_h = int(h * scale) - int(h * scale) % 16
        input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
was_resized = True
return input_image, w_original, h_original, was_resized
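# Worked example: a 5000x5000 input at 4x would produce 20000x20000 (400 MP),
# roughly 6x over budget, so process_input() shrinks it to 2048x2048
# (sqrt(MAX_PIXEL_BUDGET) / upscale_factor per side, rounded down to a
# multiple of 16).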
def load_image_from_url(url):
"""Load image from URL and convert to PNG"""
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # Buffer the full payload first: PIL needs a seekable stream, and
        # response.raw is not seekable.
        img = Image.open(io.BytesIO(response.content))
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        buffer.seek(0)
        return Image.open(buffer)
except Exception as e:
raise gr.Error(f"Failed to load image from URL: {e}")
def esrgan_upscale(image, scale=4):
    """Pre-upscale with 4x-UltraSharp ESRGAN, or LANCZOS when basicsr is missing.

    Note: the RRDBNet weights are a fixed 4x model, so `scale` only controls
    the LANCZOS fallback; callers should pass scale=4 when ESRGAN is active
    (enhance_image() does exactly that).
    """
    if not USE_ESRGAN:
        return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
    img = img2tensor(np.array(image.convert("RGB")) / 255., bgr2rgb=False, float32=True)
    with torch.no_grad():
        # The module-level model lives on `device`; move the input to match.
        output = esrgan_model(img.unsqueeze(0).to(device)).squeeze(0).cpu()
    output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
    return Image.fromarray(output_img)
def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
    """Tiled img2img in the spirit of the Ultimate SD Upscaler.

    The image is processed in tile_size crops on a stride of
    (tile_size - overlap); edges of interior tiles are alpha-blended with a
    linear ramp so seams are not visible in the stitched result.
    """
w, h = image.size
output = image.copy()
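    # FLUX uses two text encoders: CLIP-L, whose `prompt` input is capped at
    # tokenizer.model_max_length (77) tokens, and T5, whose `prompt_2` input
    # accepts much longer text. The block below truncates only the CLIP copy
    # so a long Florence-2 caption still reaches the T5 encoder intact.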
max_clip_tokens = pipe.tokenizer.model_max_length
input_ids = pipe.tokenizer.encode(prompt, return_tensors="pt")
if input_ids.shape[1] > max_clip_tokens:
input_ids = input_ids[:, :max_clip_tokens]
prompt_clip = pipe.tokenizer.decode(input_ids[0], skip_special_tokens=True)
else:
prompt_clip = prompt
for x in range(0, w, tile_size - overlap):
for y in range(0, h, tile_size - overlap):
tile_w = min(tile_size, w - x)
tile_h = min(tile_size, h - y)
tile = image.crop((x, y, x + tile_w, y + tile_h))
gen_tile = pipe(
prompt=prompt_clip,
prompt_2=prompt,
image=tile,
strength=strength,
num_inference_steps=steps,
guidance_scale=guidance,
height=tile_h,
width=tile_w,
generator=generator,
).images[0]
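            # The pipeline may round height/width internally to multiples of 16,
            # so snap the generated tile back to the exact crop size before pasting.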
gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)
if overlap > 0:
paste_box = (x, y, x + tile_w, y + tile_h)
if x > 0 or y > 0:
                    # Linear alpha ramps across the overlap band, vectorized
                    # with numpy instead of per-pixel putpixel (identical mask
                    # values, far faster).
                    mask_arr = np.full((tile_h, tile_w), 255, dtype=np.uint8)
                    if x > 0:
                        blend_width = min(overlap, tile_w)
                        mask_arr[:, :blend_width] = (255 * np.arange(blend_width) / overlap).astype(np.uint8)[None, :]
                    if y > 0:
                        blend_height = min(overlap, tile_h)
                        mask_arr[:blend_height, :] = (255 * np.arange(blend_height) / overlap).astype(np.uint8)[:, None]
                    mask = Image.fromarray(mask_arr, mode='L')
output.paste(gen_tile, paste_box, mask)
else:
output.paste(gen_tile, paste_box)
else:
output.paste(gen_tile, (x, y))
return output
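# Tiling math, for intuition: with tile_size=1024 and overlap=32 the stride is
# 992, so a 2048x2048 control image gives x, y in {0, 992, 1984}, i.e. a 3x3
# grid of nine tiles, where the last row/column are thin 64-pixel strips
# blended into their neighbors.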
def download_png(image):
"""Convert image to PNG and return base64 string for download"""
if image is None:
raise gr.Error("No upscaled image available to download")
buffer = io.BytesIO()
image.save(buffer, format="PNG")
base64_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
return base64_data
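# The base64 payload lands in the hidden #download_data textbox below; the
# <script> block near the bottom of the page is expected to watch it and
# trigger a browser download. Many Gradio versions strip or ignore <script>
# tags inside gr.HTML, so a gr.DownloadButton or gr.File output would be the
# more robust alternative if the JS path does not fire.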
@spaces.GPU(duration=120)
def enhance_image(
image_input,
image_url,
randomize_seed,
num_inference_steps,
upscale_factor,
denoising_strength,
use_generated_caption,
custom_prompt,
progress=gr.Progress(track_tqdm=True),
):
"""Main enhancement function"""
if image_input is not None:
buffer = io.BytesIO()
image_input.save(buffer, format="PNG")
buffer.seek(0)
input_image = Image.open(buffer)
elif image_url:
input_image = load_image_from_url(image_url)
else:
raise gr.Error("Please provide an image (upload or URL)")
if randomize_seed:
seed = random.randint(0, MAX_SEED)
else:
seed = 42
true_input_image = input_image
input_image, w_original, h_original, was_resized = process_input(
input_image, upscale_factor
)
if use_generated_caption:
gr.Info("πŸ” Generating image caption...")
generated_caption = generate_caption(input_image)
prompt = generated_caption
else:
        prompt = (custom_prompt or "").strip()
generator = torch.Generator().manual_seed(seed)
gr.Info("πŸš€ Upscaling image...")
if USE_ESRGAN and upscale_factor == 4:
control_image = esrgan_upscale(input_image, upscale_factor)
else:
w, h = input_image.size
control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
image = tiled_flux_img2img(
pipe,
prompt,
control_image,
denoising_strength,
num_inference_steps,
1.0,
generator,
tile_size=1024,
overlap=32
)
if was_resized:
gr.Info(f"πŸ“ Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
image = image.resize((w_original * upscale_factor, h_original * upscale_factor), resample=Image.LANCZOS)
resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
return [resized_input, image], image
# Create Gradio interface
with gr.Blocks(css=css, title="🎨 Flux dev Creative Upscaler - Florence-2 + FLUX") as demo:
gr.HTML("""
<div class="main-header">
<h1>🎨 Flux dev Creative Upscaler</h1>
<p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX dev with Ultimate SD Upscaler</p>
<p>Currently running on <strong>{}</strong></p>
</div>
""".format(power_device))
with gr.Row():
with gr.Column(scale=1):
gr.HTML("<h3>πŸ“€ Input</h3>")
with gr.Tabs():
with gr.TabItem("πŸ“ Upload Image"):
input_image = gr.Image(
label="Upload Image",
type="pil",
height=200
)
with gr.TabItem("πŸ”— Image URL"):
image_url = gr.Textbox(
label="Image URL",
placeholder="https://example.com/image.jpg",
value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
)
gr.HTML("<h3>πŸŽ›οΈ Caption Settings</h3>")
use_generated_caption = gr.Checkbox(
label="Use AI-generated caption (Florence-2)",
value=True,
info="Generate detailed caption automatically"
)
custom_prompt = gr.Textbox(
label="Custom Prompt (optional)",
placeholder="Enter custom prompt or leave empty for generated caption",
lines=2
)
gr.HTML("<h3>βš™οΈ Upscaling Settings</h3>")
upscale_factor = gr.Slider(
label="Upscale Factor",
minimum=1,
maximum=4,
step=1,
value=2,
info="How much to upscale the image"
)
num_inference_steps = gr.Slider(
label="Steps (25 Recommended)",
minimum=8,
maximum=50,
step=1,
value=25,
info="More steps = better quality but slower"
)
denoising_strength = gr.Slider(
label="Creativity (Denoising)",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.3,
info="Controls how much the image is transformed"
)
with gr.Row():
randomize_seed = gr.Checkbox(
label="Randomize seed",
value=True
)
enhance_btn = gr.Button(
"πŸš€ Upscale Image",
variant="primary",
size="lg"
)
with gr.Column(scale=2):
gr.HTML("<h3>πŸ“Š Results</h3>")
result_slider = ImageSlider(
type="pil",
interactive=False,
height=600,
elem_id="result_slider",
label=None
)
download_btn = gr.Button(
"πŸ“₯ Download as PNG",
variant="secondary",
size="lg"
)
# State to store the upscaled image
upscaled_image_state = gr.State()
# Hidden textbox for base64 data
download_data = gr.Textbox(visible=False, elem_id="download_data")
# Event handlers
enhance_btn.click(
fn=enhance_image,
inputs=[
input_image,
image_url,
randomize_seed,
num_inference_steps,
upscale_factor,
denoising_strength,
use_generated_caption,
custom_prompt,
],
outputs=[result_slider, upscaled_image_state]
)
download_btn.click(
fn=download_png,
inputs=[upscaled_image_state],
outputs=download_data
)
gr.HTML("""
<div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
        <p><strong>Note:</strong> This upscaler uses the FLUX.1-dev model, which is distributed under a non-commercial license; users are responsible for obtaining commercial rights before any commercial use.</p>
</div>
""")
gr.HTML("""
<style>
#result_slider .slider {
width: 100% !important;
max-width: inherit !important;
}
#result_slider img {
object-fit: contain !important;
width: 100% !important;
height: auto !important;
}
#result_slider .gr-button-tool {
display: none !important;
}
#result_slider .gr-button-undo {
display: none !important;
}
#result_slider .gr-button-clear {
display: none !important;
}
#result_slider .badge-container .badge {
display: none !important;
}
#result_slider .badge-container::before {
content: "Before";
position: absolute;
top: 10px;
left: 10px;
background: rgba(0,0,0,0.5);
color: white;
padding: 5px;
border-radius: 5px;
z-index: 10;
}
#result_slider .badge-container::after {
content: "After";
position: absolute;
top: 10px;
right: 10px;
background: rgba(0,0,0,0.5);
color: white;
padding: 5px;
border-radius: 5px;
z-index: 10;
}
#result_slider .fullscreen img {
object-fit: contain !important;
width: 100vw !important;
height: 100vh !important;
}
</style>
""")
gr.HTML("""
<script>
document.addEventListener('DOMContentLoaded', function() {
const sliderInput = document.querySelector('#result_slider input[type="range"]');
if (sliderInput) {
sliderInput.value = 50;
sliderInput.dispatchEvent(new Event('input'));
}
    const downloadData = document.querySelector('#download_data textarea');
    if (downloadData) {
        // MutationObserver cannot see programmatic changes to a textarea's
        // value *property*, so poll for a new base64 payload instead.
        let lastValue = '';
        setInterval(() => {
            const base64 = downloadData.value;
            if (base64 && base64 !== lastValue) {
                const byteCharacters = atob(base64);
                const byteNumbers = new Array(byteCharacters.length);
                for (let i = 0; i < byteCharacters.length; i++) {
                    byteNumbers[i] = byteCharacters.charCodeAt(i);
                }
                const byteArray = new Uint8Array(byteNumbers);
                const blob = new Blob([byteArray], {type: 'image/png'});
                const url = URL.createObjectURL(blob);
                const a = document.createElement('a');
                a.href = url;
                a.download = 'upscaled_image.png';
                a.click();
                URL.revokeObjectURL(url);
                // Clear the textbox so the same image can be downloaded again
                downloadData.value = '';
                lastValue = '';
            }
        }, 500);
    }
});
</script>
""")
if __name__ == "__main__":
    # share=True is ignored on Hugging Face Spaces; bind explicitly instead.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)