import warnings
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from diffusers import FluxImg2ImgPipeline
import random
import numpy as np
import os
import spaces
import huggingface_hub
import time

huggingface_hub.constants.HF_HUB_DOWNLOAD_TIMEOUT = 60
# basicsr is optional here; PIL's LANCZOS filter is used for all resizes,
# since basicsr provides super-resolution networks, not a PIL resample mode.
# (The previous code assigned the string "basicsr" to `interpolation`, which
# would crash Image.resize.)
try:
    import basicsr  # noqa: F401
    HAS_BASICSR = True
except ImportError:
    warnings.warn("basicsr not installed; continuing with LANCZOS-only resizing.")
    HAS_BASICSR = False
interpolation = Image.LANCZOS
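# One possible way to actually use basicsr when it is available (a sketch,
# not this app's method): upscale each tile with a pretrained RRDBNet before
# FLUX refinement. RRDBNet is a real basicsr class; the weight path below is
# a placeholder, so the sketch is left commented out.
# if HAS_BASICSR:
#     from basicsr.archs.rrdbnet_arch import RRDBNet
#     sr_net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
#                      num_block=23, num_grow_ch=32, scale=4).eval()
#     sr_net.load_state_dict(torch.load("path/to/realesrgan_x4.pth")["params_ema"])  # hypothetical path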
# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Load FLUX img2img pipeline directly to avoid auto_pipeline issues
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
    token=huggingface_token
).to(device)
pipe.enable_vae_tiling()  # To help with memory for large images
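# If VRAM is still tight, diffusers pipelines can instead offload submodules
# to CPU between forward passes (slower, but lower peak memory). To try it,
# uncomment the line below and drop the .to(device) call above:
# pipe.enable_model_cpu_offload()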
# Initialize Florence model with float32 to avoid dtype mismatch, with retry
for attempt in range(5):
    try:
        florence_model = AutoModelForCausalLM.from_pretrained(
            'microsoft/Florence-2-large',
            trust_remote_code=True,
            torch_dtype=torch.float32
        ).to(device).eval()
        florence_processor = AutoProcessor.from_pretrained(
            'microsoft/Florence-2-large',
            trust_remote_code=True
        )
        break
    except Exception as e:
        print(f"Attempt {attempt+1} to load Florence-2 failed: {e}")
        time.sleep(10)
else:
    raise RuntimeError("Failed to load Florence-2 after multiple attempts")

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
# Florence caption function
def florence_caption(image):
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    inputs = florence_processor(text="<DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task="<DETAILED_CAPTION>",
        image_size=(image.width, image.height)
    )
    return parsed_answer["<DETAILED_CAPTION>"]
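# Standalone usage sketch (assumes a local file "example.jpg" exists):
# caption = florence_caption(Image.open("example.jpg"))
# print(caption)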
# Tiled FLUX img2img: snaps dimensions to multiples of 16 and blends tile
# seams with linear alpha ramps across the top/left overlap bands. The
# previous per-pixel putpixel loops blended a freshly pasted tile against
# itself on the horizontal seam; compositing against the canvas content
# captured *before* pasting fixes that.
def tiled_flux_img2img(image, prompt, strength, num_inference_steps, guidance_scale, tile_size=512, overlap=64):
    width, height = image.size
    # Resize to multiples of 16 to avoid dimension warnings
    width = max((width // 16) * 16, 16)
    height = max((height // 16) * 16, 16)
    if (width, height) != image.size:
        image = image.resize((width, height), resample=interpolation)
    result = Image.new('RGB', (width, height))
    stride = tile_size - overlap
    # Tile in both directions, handling small sizes
    for y in range(0, height, stride):
        for x in range(0, width, stride):
            tile_right = min(x + tile_size, width)
            tile_bottom = min(y + tile_size, height)
            tile = image.crop((x, y, tile_right, tile_bottom))
            # Skip degenerate edge tiles that FLUX cannot process
            if tile.width < 16 or tile.height < 16:
                continue
            # Generate with img2img
            generated_tile = pipe(
                prompt,
                image=tile,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps
            ).images[0].resize(tile.size)  # ensure size match
            # Alpha mask: 255 (fully new tile) everywhere, with 0->255 linear
            # ramps over the top and left overlap bands where neighbours exist.
            alpha = np.full((tile.height, tile.width), 255.0)
            ov_y = min(overlap, tile.height) if y > 0 else 0
            ov_x = min(overlap, tile.width) if x > 0 else 0
            if ov_y > 0:
                alpha[:ov_y, :] = np.minimum(alpha[:ov_y, :], np.linspace(0, 255, ov_y)[:, None])
            if ov_x > 0:
                alpha[:, :ov_x] = np.minimum(alpha[:, :ov_x], np.linspace(0, 255, ov_x)[None, :])
            mask = Image.fromarray(alpha.astype(np.uint8), mode='L')
            existing = result.crop((x, y, tile_right, tile_bottom))
            result.paste(Image.composite(generated_tile, existing, mask), (x, y))
    return result
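# With the defaults, tiles advance by stride = tile_size - overlap = 448 px,
# so a 2048x2048 input is covered by 5 tile rows and 5 tile columns
# (25 pipe calls). A larger `overlap` softens seams at the cost of more tiles.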
# Main enhance function. @spaces.GPU requests a ZeroGPU slot for the call
# (the `spaces` import was previously unused, which breaks ZeroGPU Spaces).
@spaces.GPU
def enhance_image(image, text_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, strength, progress=gr.Progress(track_tqdm=True)):
    prompt = text_prompt
    if image is not None:
        prompt = florence_caption(image)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)  # gr.Number may deliver a float; manual_seed needs an int
    generator = torch.Generator(device=device).manual_seed(seed)
    # Use tiled img2img for oversized inputs, a direct pipeline call otherwise
    if image is not None and (image.size[0] > MAX_IMAGE_SIZE or image.size[1] > MAX_IMAGE_SIZE):
        output_image = tiled_flux_img2img(image, prompt, strength, num_inference_steps, guidance_scale)
    else:
        kw = {}
        if image is not None:
            kw['image'] = image
            kw['strength'] = strength
        else:
            # FluxImg2ImgPipeline always needs an init image; for text-only
            # prompts, start from a blank canvas at strength 1.0 so the init
            # image is fully replaced by noise (equivalent to txt2img).
            kw['image'] = Image.new('RGB', (int(width), int(height)))
            kw['strength'] = 1.0
        output_image = pipe(
            prompt,
            generator=generator,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            **kw
        ).images[0]
    return output_image, prompt, seed
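# Direct (non-UI) call sketch, assuming a local "photo.jpg" exists:
# img, used_prompt, used_seed = enhance_image(
#     Image.open("photo.jpg"), "", 42, False, 512, 512, 5.0, 20, 0.8)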
# Gradio interface
title = "<h1 align='center'>FLUX Image Enhancer with Florence-2 Captioner</h1>"

with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            # type="pil" so enhance_image receives a PIL Image with .size == (w, h)
            input_image = gr.Image(label="Upload Image", type="pil")
            text_prompt = gr.Textbox(label="Text Prompt (if no image)")
            strength = gr.Slider(label="Strength", minimum=0.1, maximum=1.0, value=0.8)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=5.0)
            num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=20)
            seed = gr.Number(value=42, label="Seed", precision=0)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            width = gr.Slider(minimum=256, maximum=1024, step=16, value=512, label="Width")
            height = gr.Slider(minimum=256, maximum=1024, step=16, value=512, label="Height")
            submit = gr.Button("Enhance")
        with gr.Column():
            output_image = gr.Image(label="Enhanced Image")
            output_prompt = gr.Textbox(label="Generated Prompt")
            output_seed = gr.Number(label="Used Seed")
    submit.click(
        enhance_image,
        inputs=[input_image, text_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, strength],
        outputs=[output_image, output_prompt, output_seed]
    )

print("✅ All models loaded successfully!")
demo.launch(server_port=7860, server_name="0.0.0.0")