import warnings
import os
import random

import gradio as gr
import numpy as np
import torch
import spaces
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from diffusers import FluxImg2ImgPipeline

try:
    import basicsr  # noqa: F401
    # basicsr is importable, but its upscalers are not wired up here yet,
    # so PIL's LANCZOS is used for resampling either way.
    interpolation = Image.LANCZOS
except ImportError:
    warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")
    interpolation = Image.LANCZOS

# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Load the FLUX img2img pipeline directly to avoid auto_pipeline issues
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
    token=huggingface_token,
).to(device)
pipe.enable_vae_tiling()  # helps with memory on large images

# Initialize the Florence-2 captioner in float32 to avoid dtype mismatches
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    trust_remote_code=True,
    torch_dtype=torch.float32,
).to(device).eval()
florence_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large", trust_remote_code=True
)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Florence-2 task token, used both as the processor prompt and as the key
# into the parsed answer; the detailed-caption task is assumed here
CAPTION_TASK = "<MORE_DETAILED_CAPTION>"

# Florence caption function
@spaces.GPU
def florence_caption(image):
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    inputs = florence_processor(
        text=CAPTION_TASK, images=image, return_tensors="pt"
    ).to(device)
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = florence_processor.batch_decode(
        generated_ids, skip_special_tokens=False
    )[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=CAPTION_TASK,
        image_size=(image.width, image.height),
    )
    return parsed_answer[CAPTION_TASK]

# Tiled FLUX img2img with fixes for small dimensions and overlap handling
def tiled_flux_img2img(image, prompt, strength, num_inference_steps,
                       guidance_scale, tile_size=512, overlap=64):
    width, height = image.size
    # Snap to multiples of 16 to avoid dimension warnings from the pipeline
    width = (width // 16) * 16 if width >= 16 else 16
    height = (height // 16) * 16 if height >= 16 else 16
    if (width, height) != image.size:
        image = image.resize((width, height), resample=interpolation)

    result = Image.new("RGB", (width, height))
    # Adjacent tiles share an `overlap`-pixel band for seam blending
    stride = tile_size - overlap

    # Tile in both directions; edge tiles may be smaller than tile_size
    for y in range(0, height, stride):
        for x in range(0, width, stride):
            tile_left, tile_top = x, y
            tile_right = min(x + tile_size, width)
            tile_bottom = min(y + tile_size, height)
            tile = image.crop((tile_left, tile_top, tile_right, tile_bottom))

            # Skip slivers too small for the pipeline
            if tile.width < 16 or tile.height < 16:
                continue

            # Re-generate this tile with img2img
            generated_tile = pipe(
                prompt,
                image=tile,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
            ).images[0]
            generated_tile = generated_tile.resize(tile.size)  # ensure size match

            # First tile: paste directly, nothing to blend against
            if x == 0 and y == 0:
                result.paste(generated_tile, (tile_left, tile_top))
                continue

            if y > 0:  # Vertical blend against the tile above
                effective_overlap = min(overlap, tile_bottom - tile_top)
                if effective_overlap > 0:
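                    # Linear alpha ramp for Image.composite: 0 at the top of
                    # the overlap band (keep the previously pasted tile above)
                    # rising to 255 at the bottom (take the newly generated
                    # tile), so the two generations cross-fade across the seam.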
                    mask = Image.new("L", (tile_right - tile_left, effective_overlap))
                    # Guard the divisor for one-pixel bands
                    denom = effective_overlap - 1 if effective_overlap > 1 else 1
                    for i in range(mask.width):
                        for j in range(mask.height):
                            mask.putpixel((i, j), int(255 * (j / denom)))
                    # Blend the top band of this tile with the bottom of the tile above
                    blend_region = Image.composite(
                        generated_tile.crop((0, 0, mask.width, mask.height)),
                        result.crop((tile_left, tile_top, tile_right, tile_top + mask.height)),
                        mask,
                    )
                    result.paste(blend_region, (tile_left, tile_top))
                    # Paste the non-overlapping remainder below the band
                    result.paste(
                        generated_tile.crop((0, effective_overlap, generated_tile.width, generated_tile.height)),
                        (tile_left, tile_top + effective_overlap),
                    )
                else:
                    result.paste(generated_tile, (tile_left, tile_top))

            if x > 0:  # Horizontal blend against the tile to the left, analogous
                # Note: on interior tiles this overwrites part of the vertical
                # blend above; acceptable for this simple scheme.
                effective_overlap_h = min(overlap, tile_right - tile_left)
                if effective_overlap_h > 0:
                    mask_h = Image.new("L", (effective_overlap_h, tile_bottom - tile_top))
                    denom_h = effective_overlap_h - 1 if effective_overlap_h > 1 else 1
                    for i in range(mask_h.width):
                        for j in range(mask_h.height):
                            mask_h.putpixel((i, j), int(255 * (i / denom_h)))
                    # Blend the left band of this tile with the right of the tile to the left
                    blend_region_h = Image.composite(
                        generated_tile.crop((0, 0, mask_h.width, mask_h.height)),
                        result.crop((tile_left, tile_top, tile_left + mask_h.width, tile_bottom)),
                        mask_h,
                    )
                    result.paste(blend_region_h, (tile_left, tile_top))
                    # Paste the non-overlapping remainder to the right of the band
                    result.paste(
                        generated_tile.crop((effective_overlap_h, 0, generated_tile.width, generated_tile.height)),
                        (tile_left + effective_overlap_h, tile_top),
                    )
                else:
                    result.paste(generated_tile, (tile_left, tile_top))

    return result

# Main enhance function
@spaces.GPU(duration=190)
def enhance_image(image, text_prompt, seed, randomize_seed, width, height,
                  guidance_scale, num_inference_steps, strength,
                  progress=gr.Progress(track_tqdm=True)):
    # Prefer a Florence-2 caption of the uploaded image over the text prompt
    prompt = text_prompt
    if image is not None:
        prompt = florence_caption(image)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Tile oversized inputs, otherwise run a single pass
    if image is not None and (image.size[0] > MAX_IMAGE_SIZE or image.size[1] > MAX_IMAGE_SIZE):
        output_image = tiled_flux_img2img(
            image, prompt, strength, num_inference_steps, guidance_scale
        )
    else:
        # FluxImg2ImgPipeline always needs an init image; with no upload,
        # start from a blank canvas at strength 1.0, which fully re-noises
        # it and approximates text-to-image generation
        init_image = image if image is not None else Image.new("RGB", (width, height))
        output_image = pipe(
            prompt,
            image=init_image,
            generator=generator,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            strength=strength if image is not None else 1.0,
        ).images[0]

    return output_image, prompt, seed
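# A minimal usage sketch of the tiled path, kept as a comment so the module
# stays import-safe; the file names and parameter values are hypothetical:
#
#   big = Image.open("input_4k.png").convert("RGB")
#   out = tiled_flux_img2img(big, "a detailed cityscape photo",
#                            strength=0.6, num_inference_steps=20,
#                            guidance_scale=5.0)
#   out.save("output_4k.png")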

# Gradio interface
title = """<h1 align="center">FLUX Image Enhancer with Florence-2 Captioner</h1>"""
" with gr.Blocks() as demo: gr.HTML(title) with gr.Row(): with gr.Column(): input_image = gr.Image(label="Upload Image") text_prompt = gr.Textbox(label="Text Prompt (if no image)") strength = gr.Slider(label="Strength", minimum=0.1, maximum=1.0, value=0.8) guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=5.0) num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=20) seed = gr.Number(value=42, label="Seed") randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) width = gr.Slider(minimum=256, maximum=1024, step=16, value=512, label="Width") height = gr.Slider(minimum=256, maximum=1024, step=16, value=512, label="Height") submit = gr.Button("Enhance") with gr.Column(): output_image = gr.Image(label="Enhanced Image") output_prompt = gr.Textbox(label="Generated Prompt") output_seed = gr.Number(label="Used Seed") submit.click( enhance_image, inputs=[input_image, text_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, strength], outputs=[output_image, output_prompt, output_seed] ) print("✅ All models loaded successfully!") demo.launch(server_port=7860, server_name="0.0.0.0")