import warnings
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from diffusers import FluxImg2ImgPipeline
import random
import numpy as np
import os
import spaces
import huggingface_hub
import time

huggingface_hub.constants.HF_HUB_DOWNLOAD_TIMEOUT = 60

# basicsr is optional; resizing uses PIL's LANCZOS filter either way, since a
# bare string is not a valid `resample` argument for Image.resize()
try:
    import basicsr  # noqa: F401
except ImportError:
    warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")
interpolation = Image.LANCZOS

# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Load the FLUX img2img pipeline directly to avoid auto_pipeline issues
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
    token=huggingface_token
).to(device)
pipe.enable_vae_tiling()  # Helps with memory for large images

# Initialize Florence-2 in float32 to avoid dtype mismatches, retrying a few
# times in case of transient download failures
for attempt in range(5):
    try:
        florence_model = AutoModelForCausalLM.from_pretrained(
            'microsoft/Florence-2-large',
            trust_remote_code=True,
            torch_dtype=torch.float32
        ).to(device).eval()
        florence_processor = AutoProcessor.from_pretrained(
            'microsoft/Florence-2-large',
            trust_remote_code=True
        )
        break
    except Exception as e:
        print(f"Attempt {attempt + 1} to load Florence-2 failed: {e}")
        time.sleep(10)
else:
    raise RuntimeError("Failed to load Florence-2 after multiple attempts")

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Florence-2 caption function. The <MORE_DETAILED_CAPTION> task token (which
# an HTML-stripping pass had reduced to empty strings here) is Florence-2's
# detailed-captioning prompt; it also keys the parsed output dict.
@spaces.GPU
def florence_caption(image):
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    inputs = florence_processor(
        text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt"
    ).to(device)
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = florence_processor.batch_decode(
        generated_ids, skip_special_tokens=False
    )[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task="<MORE_DETAILED_CAPTION>",
        image_size=(image.width, image.height)
    )
    return parsed_answer["<MORE_DETAILED_CAPTION>"]

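# Tiling geometry (a worked example from the defaults below): with
# tile_size=512 and overlap=64 the stride is 448, so a 1024-px edge yields
# tiles starting at x = 0, 448, 896; the last tile is cropped at the image
# edge, and each adjacent pair shares a 64-px strip that gets alpha-blended.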
# Tiled FLUX img2img function, with fixes for small dimensions and overlap
def tiled_flux_img2img(image, prompt, strength, num_inference_steps,
                       guidance_scale, tile_size=512, overlap=64):
    width, height = image.size
    # Snap to a multiple of 16 to avoid dimension warnings
    width = (width // 16) * 16 if width >= 16 else 16
    height = (height // 16) * 16 if height >= 16 else 16
    if (width, height) != image.size:
        image = image.resize((width, height), resample=interpolation)

    result = Image.new('RGB', (width, height))
    stride = tile_size - overlap

    # Tile in both directions, handling small sizes
    for y in range(0, height, stride):
        for x in range(0, width, stride):
            tile_left, tile_top = x, y
            tile_right = min(x + tile_size, width)
            tile_bottom = min(y + tile_size, height)
            tile = image.crop((tile_left, tile_top, tile_right, tile_bottom))

            # Skip tiles too small for the pipeline
            if tile.width < 16 or tile.height < 16:
                continue

            # Regenerate the tile with img2img
            generated_tile = pipe(
                prompt,
                image=tile,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps
            ).images[0]
            generated_tile = generated_tile.resize(tile.size)  # Ensure size match

            # The first tile is pasted without blending
            if x == 0 and y == 0:
                result.paste(generated_tile, (tile_left, tile_top))
                continue

            # Vertical blend: fade this tile's top strip into the tile above
            if y > 0:
                effective_overlap = min(overlap, tile_bottom - tile_top, height - tile_top)
                if effective_overlap > 0:
                    mask = Image.new('L', (tile_right - tile_left, effective_overlap))
                    for i in range(mask.width):
                        for j in range(mask.height):
                            divisor = effective_overlap - 1 if effective_overlap > 1 else 1
                            mask.putpixel((i, j), int(255 * (j / divisor)))
                    blend_region = Image.composite(
                        generated_tile.crop((0, 0, mask.width, mask.height)),
                        result.crop((tile_left, tile_top, tile_right, tile_top + mask.height)),
                        mask
                    )
                    result.paste(blend_region, (tile_left, tile_top))
                    result.paste(
                        generated_tile.crop((0, effective_overlap, generated_tile.width, generated_tile.height)),
                        (tile_left, tile_top + effective_overlap)
                    )
                else:
                    result.paste(generated_tile, (tile_left, tile_top))

            # Horizontal blend: fade this tile's left strip into the tile to the left
            if x > 0:
                effective_overlap_h = min(overlap, tile_right - tile_left, width - tile_left)
                if effective_overlap_h > 0:
                    mask_h = Image.new('L', (effective_overlap_h, tile_bottom - tile_top))
                    for i in range(mask_h.width):
                        for j in range(mask_h.height):
                            divisor_h = effective_overlap_h - 1 if effective_overlap_h > 1 else 1
                            mask_h.putpixel((i, j), int(255 * (i / divisor_h)))
                    blend_region_h = Image.composite(
                        generated_tile.crop((0, 0, mask_h.width, mask_h.height)),
                        result.crop((tile_left, tile_top, tile_left + mask_h.width, tile_bottom)),
                        mask_h
                    )
                    result.paste(blend_region_h, (tile_left, tile_top))
                    result.paste(
                        generated_tile.crop((effective_overlap_h, 0, generated_tile.width, generated_tile.height)),
                        (tile_left + effective_overlap_h, tile_top)
                    )
                else:
                    result.paste(generated_tile, (tile_left, tile_top))

    return result

# Main enhance function
@spaces.GPU(duration=190)
def enhance_image(image, text_prompt, seed, randomize_seed, width, height,
                  guidance_scale, num_inference_steps, strength,
                  progress=gr.Progress(track_tqdm=True)):
    prompt = text_prompt
    if image is not None:
        prompt = florence_caption(image)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)  # gr.Number may deliver a float; manual_seed needs an int
    generator = torch.Generator(device=device).manual_seed(seed)

    # Use tiling for oversized inputs, otherwise run the pipeline directly
    if image is not None and (image.size[0] > MAX_IMAGE_SIZE or image.size[1] > MAX_IMAGE_SIZE):
        output_image = tiled_flux_img2img(image, prompt, strength, num_inference_steps, guidance_scale)
    else:
        kw = {}
        if image is not None:
            kw['image'] = image
            kw['strength'] = strength
        else:
            # FluxImg2ImgPipeline has no pure text-to-image path, so start
            # from a blank canvas with strength 1.0, which is fully re-noised
            kw['image'] = Image.new('RGB', (int(width), int(height)))
            kw['strength'] = 1.0
            kw['width'] = int(width)
            kw['height'] = int(height)
        output_image = pipe(
            prompt,
            generator=generator,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            **kw
        ).images[0]

    return output_image, prompt, seed
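
# The per-pixel putpixel loops in tiled_flux_img2img are easy to follow but
# slow for large overlaps. A vectorized sketch of the same 0 -> 255 linear
# ramp (an illustrative helper, not wired into the function above):
def linear_blend_mask(width, height, horizontal=True):
    n = width if horizontal else height
    ramp = np.linspace(0, 255, num=n).astype(np.uint8)
    if horizontal:
        arr = np.tile(ramp, (height, 1))          # ramp left -> right
    else:
        arr = np.tile(ramp[:, None], (1, width))  # ramp top -> bottom
    return Image.fromarray(arr, mode='L')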

# Gradio interface
title = """<h1 align="center">FLUX Image Enhancer with Florence-2 Captioner</h1>"""

with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="pil")  # PIL so image.size is (w, h)
            text_prompt = gr.Textbox(label="Text Prompt (if no image)")
            strength = gr.Slider(label="Strength", minimum=0.1, maximum=1.0, value=0.8)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=5.0)
            num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=20)
            seed = gr.Number(value=42, label="Seed")
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            width = gr.Slider(minimum=256, maximum=1024, step=16, value=512, label="Width")
            height = gr.Slider(minimum=256, maximum=1024, step=16, value=512, label="Height")
            submit = gr.Button("Enhance")
        with gr.Column():
            output_image = gr.Image(label="Enhanced Image")
            output_prompt = gr.Textbox(label="Generated Prompt")
            output_seed = gr.Number(label="Used Seed")

    submit.click(
        enhance_image,
        inputs=[input_image, text_prompt, seed, randomize_seed, width, height,
                guidance_scale, num_inference_steps, strength],
        outputs=[output_image, output_prompt, output_seed]
    )

print("✅ All models loaded successfully!")
demo.launch(server_port=7860, server_name="0.0.0.0")