# Spaces: Running on Zero
import io
import logging
import os
import random
import warnings

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from gradio_imageslider import ImageSlider
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Extra stylesheet passed to gr.Blocks(css=...): centres/limits an element with
# id "col-container" and centres the ".main-header" banner.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 800px;\n"
    "}\n"
    ".main-header {\n"
    "    text-align: center;\n"
    "    margin-bottom: 2rem;\n"
    "}\n"
)
# Device setup: decided once at import time. `device` is the torch device
# string used for all models; `power_device` is the label shown in the UI header.
device = "cuda" if torch.cuda.is_available() else "cpu"
power_device = {"cuda": "GPU", "cpu": "CPU"}[device]
# Get HuggingFace token (None when the HF_TOKEN environment variable is unset).
huggingface_token = os.getenv("HF_TOKEN")
# Download the FLUX.1-dev repository into ./FLUX.1-dev; `model_path` is the
# local directory later handed to FluxControlNetPipeline.from_pretrained.
print("π₯ Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=[
        "*.md",
        # fnmatch-style pattern: "*" also matches the empty string, so this
        # skips ".gitattributes" (the former "*..gitattributes" matched nothing).
        "*.gitattributes",
        # Root-level single-file checkpoints for the reference (non-diffusers)
        # implementation. The diffusers pipeline loads the per-component
        # subfolders instead, so downloading these only wastes time and disk.
        "flux1-dev.safetensors",
        "ae.safetensors",
    ],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
# Load Florence-2 model for image captioning.
# Half precision is only used on CUDA: float16 inference on CPU is unsupported
# or very slow for many ops, so CPU runs fall back to float32.
print("π₯ Loading Florence-2 model...")
florence_dtype = torch.float16 if device == "cuda" else torch.float32
# NOTE(review): trust_remote_code=True executes model code shipped in the Hub
# repository; acceptable for a pinned, trusted repo such as this one.
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    torch_dtype=florence_dtype,
    trust_remote_code=True,
    attn_implementation="eager",  # Fix for SDPA compatibility issue
).to(device)
florence_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large",
    trust_remote_code=True,
)
# Load the FLUX ControlNet upscaler and wrap it, together with the downloaded
# FLUX.1-dev weights, in a single pipeline; everything runs in bfloat16.
print("π₯ Loading FLUX ControlNet...")
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
)
controlnet = controlnet.to(device)
pipe = FluxControlNetPipeline.from_pretrained(
    model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to(device)
print("β All models loaded successfully!")
# Largest seed value (inclusive) for random.randint and the seed slider.
MAX_SEED = 1_000_000
# Maximum width * height, in pixels, of the image the FLUX pipeline generates.
MAX_PIXEL_BUDGET = 1024**2
def generate_caption(image):
    """Generate a detailed caption for *image* using Florence-2.

    Args:
        image: PIL image to describe.

    Returns:
        The text produced for the ``<MORE_DETAILED_CAPTION>`` task, or the
        generic fallback ``"a high quality detailed image"`` if captioning
        fails for any reason. Captioning is best-effort: upscaling proceeds
        either way, but the failure is logged with its traceback.
    """
    task_prompt = "<MORE_DETAILED_CAPTION>"
    try:
        # The processor expects 3-channel input; RGBA / L / P uploads would
        # otherwise fail or be normalised incorrectly.
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Cast floating tensors (pixel_values) to the model's dtype as well as
        # moving them to its device. Leaving them float32 against a float16
        # model raises a dtype mismatch, which previously made every call fall
        # through to the generic fallback caption.
        inputs = florence_processor(
            text=task_prompt, images=image, return_tensors="pt"
        ).to(device, florence_model.dtype)
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            # Deterministic beam search: the caption becomes the FLUX prompt,
            # so sampling here would make the reported seed non-reproducible.
            do_sample=False,
        )
        generated_text = florence_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed_answer = florence_processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        return parsed_answer[task_prompt]
    except Exception:
        # Deliberately best-effort, but keep the traceback visible in the logs.
        logging.exception("Caption generation failed")
        return "a high quality detailed image"
def process_input(input_image, upscale_factor, max_pixel_budget=None):
    """Fit *input_image* to the generation pixel budget and round it to multiples of 8.

    Args:
        input_image: PIL image (anything exposing ``.size`` and ``.resize``).
        upscale_factor: Integer-valued scale factor. Floats such as ``2.0``
            (as UI sliders may deliver) are accepted and truncated to int.
        max_pixel_budget: Maximum ``width * height`` of the upscaled result.
            Defaults to the module-level ``MAX_PIXEL_BUDGET``.

    Returns:
        ``(image, w_original, h_original, was_resized)``:
        - ``image`` has a width and height that are multiples of 8 and at least 8.
        - ``w_original`` and ``h_original`` are the size of the input as passed in.
        - ``was_resized`` is True when the input was shrunk to respect the budget.
    """
    if max_pixel_budget is None:
        max_pixel_budget = MAX_PIXEL_BUDGET
    # PIL sizes must be ints; a float factor would make every resize fail.
    upscale_factor = int(upscale_factor)

    w, h = input_image.size
    w_original, h_original = w, h
    aspect_ratio = w / h
    was_resized = False

    if w * h * upscale_factor**2 > max_pixel_budget:
        warnings.warn(
            f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to fit budget."
        )
        gr.Info(
            "Requested output image is too large. Resizing input to fit within pixel budget."
        )
        # Shrink so that (w' * factor) * (h' * factor) ~= budget, keeping the
        # aspect ratio. Clamp to 1 px: an extreme aspect ratio can round one
        # side to 0, which PIL refuses to resize to.
        side = max_pixel_budget**0.5
        input_image = input_image.resize(
            (
                max(1, int(aspect_ratio * side // upscale_factor)),
                max(1, int(side // aspect_ratio // upscale_factor)),
            )
        )
        was_resized = True

    # Resize to multiple of 8 (required by the FLUX pipeline). Never go below
    # 8 px, or tiny / very elongated images would collapse to a 0-sized side.
    w, h = input_image.size
    w = max(8, w - w % 8)
    h = max(8, h - h % 8)
    return input_image.resize((w, h)), w_original, h_original, was_resized
def load_image_from_url(url, timeout=30):
    """Download an image over HTTP(S) and return it as a fully decoded PIL image.

    Args:
        url: Image URL typed by the user (surrounding whitespace is ignored).
        timeout: Seconds to wait on the server (connect/read) before giving up.

    Returns:
        A ``PIL.Image.Image`` whose pixel data has already been loaded.

    Raises:
        gr.Error: If the request fails, returns an error status, or the body
            cannot be decoded as an image.
    """
    # NOTE(review): the URL comes straight from the UI, so this is a
    # server-side fetch of an arbitrary user-chosen address. requests only
    # speaks http(s), but internal hosts remain reachable; restrict if needed.
    url = url.strip()
    try:
        # A single request (the URL used to be fetched twice), with a timeout
        # so a stalled server cannot hang the worker indefinitely.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # Decode from the body bytes rather than `response.raw`, which hands
        # PIL still gzip/deflate-encoded data when the server compresses.
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; force decoding so corrupt data is reported here.
        image.load()
        return image
    except Exception as e:
        raise gr.Error(f"Failed to load image from URL: {e}") from e
def _resolve_input_image(image_input, image_url):
    """Return the uploaded image, else the one fetched from *image_url*, as RGB."""
    if image_input is not None:
        image = image_input
    elif image_url and image_url.strip():
        image = load_image_from_url(image_url)
    else:
        raise gr.Error("Please provide an image (upload or URL)")
    # FLUX and Florence-2 expect 3-channel input; PNG uploads are often RGBA/P/L.
    return image if image.mode == "RGB" else image.convert("RGB")


def _resolve_prompt(use_generated_caption, custom_prompt, image):
    """Return ``(prompt, caption_for_ui)``: Florence-2 caption or the custom prompt."""
    if use_generated_caption:
        gr.Info("π Generating image caption...")
        caption = generate_caption(image)
        return caption, caption
    # An empty or missing (None) custom prompt means "no prompt".
    return (custom_prompt or "").strip(), ""


@spaces.GPU  # ZeroGPU Spaces only attach a GPU to functions decorated like this
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    controlnet_conditioning_scale,
    guidance_scale,
    use_generated_caption,
    custom_prompt,
    progress=gr.Progress(track_tqdm=True),  # mirrors the pipeline's tqdm bar in the UI
):
    """Upscale an uploaded or URL-provided image with the FLUX ControlNet upscaler.

    The parameters map one-to-one onto the Gradio inputs.

    Returns:
        A 3-tuple, one value per output component:
        ``((original_image, upscaled_image), seed_used, caption)``.
        ``caption`` is the Florence-2 caption, or ``""`` when the custom prompt
        was used.

    Raises:
        gr.Error: If no image is provided or the URL cannot be loaded.
    """
    input_image = _resolve_input_image(image_input, image_url)

    # UI sliders may deliver floats (e.g. 2.0); PIL sizes, torch seeds and the
    # scheduler's step count all need ints.
    upscale_factor = int(upscale_factor)
    num_inference_steps = int(num_inference_steps)
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    true_input_image = input_image
    input_image, w_original, h_original, was_resized = process_input(
        input_image, upscale_factor
    )
    prompt, caption_for_ui = _resolve_prompt(
        use_generated_caption, custom_prompt, input_image
    )

    # The ControlNet conditions on the (naively) upscaled input.
    w, h = input_image.size
    control_image = input_image.resize((w * upscale_factor, h * upscale_factor))
    generator = torch.Generator().manual_seed(seed)

    gr.Info("π Upscaling image...")
    image = pipe(
        prompt=prompt,
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]

    target_size = (w_original * upscale_factor, h_original * upscale_factor)
    if was_resized:
        gr.Info(f"π Resizing output to target size: {target_size[0]}x{target_size[1]}")
    # Undo both the budget shrinking and the multiple-of-8 rounding so the
    # result always has exactly original size x factor.
    if image.size != target_size:
        image = image.resize(target_size)

    # ImageSlider takes an (input, output) pair; then the seed and caption.
    return (true_input_image, image), seed, caption_for_ui
# Create Gradio interface | |
# Create Gradio interface
with gr.Blocks(css=css, title="π¨ AI Image Enhancer - Florence-2 + FLUX") as demo:
    # Page header; `power_device` is the device label computed at startup.
    gr.HTML("""
        <div class="main-header">
            <h1>π¨ AI Image Enhancer</h1>
            <p>Upload an image or provide a URL to enhance it using Florence-2 captioning and FLUX upscaling</p>
            <p>Currently running on <strong>{}</strong></p>
        </div>
    """.format(power_device))
    with gr.Row():
        # Left column: image source, prompt options and generation settings.
        with gr.Column(scale=1):
            gr.HTML("<h3>π€ Input</h3>")
            with gr.Tabs():
                with gr.TabItem("π Upload Image"):
                    input_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=300
                    )
                with gr.TabItem("π Image URL"):
                    # Only used when no image has been uploaded.
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg",
                        value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
                    )
            gr.HTML("<h3>ποΈ Caption Settings</h3>")
            # When checked, the Florence-2 caption replaces the custom prompt.
            use_generated_caption = gr.Checkbox(
                label="Use AI-generated caption (Florence-2)",
                value=True,
                info="Generate detailed caption automatically"
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Enter custom prompt or leave empty for generated caption",
                lines=2
            )
            gr.HTML("<h3>βοΈ Enhancement Settings</h3>")
            upscale_factor = gr.Slider(
                label="Upscale Factor",
                minimum=1,
                maximum=4,
                step=1,
                value=2,
                info="How much to upscale the image"
            )
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=8,
                maximum=50,
                step=1,
                value=28,
                info="More steps = better quality but slower"
            )
            controlnet_conditioning_scale = gr.Slider(
                label="ControlNet Conditioning Scale",
                minimum=0.1,
                maximum=1.5,
                step=0.1,
                value=0.6,
                info="How much to preserve original structure"
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1.0,
                maximum=10.0,
                step=0.5,
                value=3.5,
                info="How closely to follow the prompt"
            )
            with gr.Row():
                randomize_seed = gr.Checkbox(
                    label="Randomize seed",
                    value=True
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                    interactive=True
                )
            enhance_btn = gr.Button(
                "π Enhance Image",
                variant="primary",
                size="lg"
            )
        # Right column: before/after comparison plus run metadata.
        with gr.Column(scale=1):
            gr.HTML("<h3>π Results</h3>")
            result_slider = ImageSlider(
                label="Input / Enhanced",
                type="pil",
                interactive=True,
                height=400
            )
            with gr.Row():
                output_seed = gr.Number(
                    label="Used Seed",
                    precision=0,
                    interactive=False
                )
            generated_caption_output = gr.Textbox(
                label="Generated Caption",
                placeholder="AI-generated caption will appear here...",
                lines=3,
                interactive=False
            )
    # Examples: each row supplies one value per entry in `inputs`, in order.
    gr.Examples(
        examples=[
            [None, "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg", 42, False, 28, 2, 0.6, 3.5, True, ""],
            [None, "https://picsum.photos/512/512", 123, False, 25, 3, 0.8, 4.0, True, ""],
        ],
        inputs=[
            input_image,
            image_url,
            seed,
            randomize_seed,
            num_inference_steps,
            upscale_factor,
            controlnet_conditioning_scale,
            guidance_scale,
            use_generated_caption,
            custom_prompt,
        ]
    )
    # Event handler: enhance_image must return exactly one value per output
    # component listed here (three).
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            image_url,
            seed,
            randomize_seed,
            num_inference_steps,
            upscale_factor,
            controlnet_conditioning_scale,
            guidance_scale,
            use_generated_caption,
            custom_prompt,
        ],
        outputs=[result_slider, output_seed, generated_caption_output]
    )
    gr.HTML("""
        <div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
            <h4>π‘ How it works:</h4>
            <ol>
                <li><strong>Florence-2</strong> analyzes your image and generates a detailed caption</li>
                <li><strong>FLUX ControlNet</strong> uses this caption to guide the upscaling process</li>
                <li>The result is an enhanced, higher-resolution image with improved details</li>
            </ol>
            <p><strong>Note:</strong> Due to memory constraints, generation is limited to a 1024x1024 pixel budget; larger requests are generated at reduced resolution and then resized to the requested size.</p>
        </div>
    """)
if __name__ == "__main__":
    # Enable request queuing (needed for progress tracking), then serve on all
    # interfaces at port 7860 with a public share link requested.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860, share=True)