"""Gradio app: restyle an uploaded image with FLUX.1-dev + OmniConsistency.

At import time this downloads the OmniConsistency base LoRA, loads the FLUX
pipeline onto the available device and attaches the base LoRA to the
pipeline's transformer via `set_single_lora`. Each request then loads one
style LoRA (built-in, or the first `.safetensors` file of a user-supplied Hub
repo) and runs the pipeline with the uploaded image passed as
`spatial_images`.
"""

# NOTE: `spaces` is kept as the first import, ahead of torch, exactly as in
# the original file; ZeroGPU Spaces presumably depend on this order.
import spaces
import os
import time

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, list_repo_files

from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora

BASE_PATH = "black-forest-labs/FLUX.1-dev"
LOCAL_LORA_DIR = "./LoRAs"
CUSTOM_LORA_DIR = "./Custom_LoRAs"
os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)

# ------------------ DEVICE SETUP ------------------ #
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# bfloat16 on GPU; CPU falls back to float32.
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

print(f"🚀 Running on device: {device}")
if device.type == "cpu":
    print("⚠️ WARNING: No CUDA detected. Running on CPU. Generation may be slow.")

# ------------------ Load Base LoRA ------------------ #
print("downloading OmniConsistency base LoRA …")
omni_consistency_path = hf_hub_download(
    repo_id="showlab/OmniConsistency",
    filename="OmniConsistency.safetensors",
    local_dir="./Model"
)

print("loading base pipeline …")
pipe = FluxPipeline.from_pretrained(
    BASE_PATH, torch_dtype=dtype
).to(device)
set_single_lora(pipe.transformer, omni_consistency_path,
                lora_weights=[1], cond_size=512)


# ------------------ Util ------------------ #
def clear_cache(transformer):
    """Clear the `bank_kv` cache on every attention processor of *transformer*.

    Processors that have no `bank_kv` attribute are skipped, so this is safe
    to call from a `finally` clause without masking an earlier exception.
    """
    # NOTE(review): bank_kv is presumably a K/V cache filled while the
    # pipeline runs; it is emptied after every request so state does not
    # carry over to the next one.
    for attn_processor in transformer.attn_processors.values():
        bank_kv = getattr(attn_processor, "bank_kv", None)
        if bank_kv is not None:
            bank_kv.clear()


def _parse_size(value, label):
    """Convert a width/height box value to an int >= 8, or raise gr.Error."""
    try:
        size = int(value)
    except (TypeError, ValueError):
        raise gr.Error(f"{label} must be a whole number, got {value!r}.") from None
    # The pipeline receives the size rounded down to a multiple of 8, so
    # anything below 8 would collapse to 0.
    if size < 8:
        raise gr.Error(f"{label} must be at least 8, got {size}.")
    return size


def _resolve_lora_path(lora_name, custom_repo_id):
    """Download the requested style LoRA and return its local file path.

    A non-blank *custom_repo_id* takes precedence over *lora_name*; in that
    case the first `.safetensors` file listed in the repo is used. Otherwise
    the built-in `LoRAs/<lora_name>_rank128_bf16.safetensors` file from
    `showlab/OmniConsistency` is downloaded.

    Raises:
        gr.Error: if nothing was selected or the download/lookup fails.
    """
    # Custom LoRA path
    if custom_repo_id and custom_repo_id.strip():
        repo_id = custom_repo_id.strip()
        try:
            files = list_repo_files(repo_id)
            print("using custom LoRA from:", repo_id)
            safetensors_files = [f for f in files if f.endswith(".safetensors")]
            print("found safetensors files:", safetensors_files)
            if not safetensors_files:
                raise ValueError("No .safetensors files were found in this repo")
            return hf_hub_download(
                repo_id=repo_id,
                filename=safetensors_files[0],
                local_dir=CUSTOM_LORA_DIR,
            )
        except Exception as e:
            raise gr.Error(f"Load custom LoRA failed: {e}") from e

    # Built-in LoRA: download only the one selected
    if not lora_name:
        raise gr.Error("Please select a built-in LoRA or enter a custom LoRA repo.")
    lora_filename = f"LoRAs/{lora_name}_rank128_bf16.safetensors"
    try:
        return hf_hub_download(
            repo_id="showlab/OmniConsistency",
            filename=lora_filename,
            local_dir=LOCAL_LORA_DIR,
        )
    except Exception as e:
        raise gr.Error(f"Download LoRA failed: {e}") from e


# ------------------ Generation ------------------ #
@spaces.GPU()
def generate_image(
    lora_name,
    custom_repo_id,
    prompt,
    uploaded_image,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    seed
):
    """Restyle *uploaded_image* with the selected LoRA.

    Args:
        lora_name: built-in LoRA name (ignored when *custom_repo_id* is set).
        custom_repo_id: optional Hub repo id holding a `.safetensors` LoRA.
        prompt: text prompt passed to the pipeline.
        uploaded_image: input image; converted to RGB and passed as the
            pipeline's `spatial_images`.
        width, height: requested output size (rounded down to a multiple of 8).
        guidance_scale: guidance scale passed to the pipeline.
        num_inference_steps: number of denoising steps.
        seed: seed for a CPU `torch.Generator`.

    Returns:
        `(uploaded_image, out_img)` where `out_img` is the first image of the
        pipeline output — the before/after pair shown by the ImageSlider.

    Raises:
        gr.Error: on invalid sizes, a missing image, or LoRA loading failures.
    """
    width = _parse_size(width, "Width")
    height = _parse_size(height, "Height")
    if uploaded_image is None:
        raise gr.Error("Please upload an image.")
    # torch.Generator.manual_seed and the step count need ints; coerce in case
    # the UI or API delivers floats.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator("cpu").manual_seed(seed)

    lora_path = _resolve_lora_path(lora_name, custom_repo_id)

    pipe.unload_lora_weights()
    try:
        pipe.load_lora_weights(
            os.path.dirname(lora_path),
            weight_name=os.path.basename(lora_path)
        )
    except Exception as e:
        raise gr.Error(f"Load LoRA failed: {e}") from e

    spatial_image = [uploaded_image.convert("RGB")]
    subject_images = []

    start = time.time()
    try:
        out_img = pipe(
            prompt,
            height=(height // 8) * 8,
            width=(width // 8) * 8,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            max_sequence_length=512,
            generator=generator,
            spatial_images=spatial_image,
            subject_images=subject_images,
            cond_size=512,
        ).images[0]
    finally:
        # Clear the per-processor cache even when inference fails, so a failed
        # request cannot leave state behind for the next one.
        clear_cache(pipe.transformer)
    print(f"inference time: {time.time()-start:.2f}s")

    return uploaded_image, out_img


# ------------------ Gradio UI ------------------ #
def create_interface():
    """Build and return the Gradio Blocks UI wired to `generate_image`."""
    demo_lora_names = [
        "3D_Chibi", "American_Cartoon", "Chinese_Ink",
        "Clay_Toy", "Fabric", "Ghibli", "Irasutoya",
        "Jojo", "LEGO", "Line", "Macaron",
        "Oil_Painting", "Origami", "Paper_Cutting",
        "Picasso", "Pixel", "Poly", "Pop_Art",
        "Rick_Morty", "Snoopy", "Van_Gogh", "Vector"
    ]

    def update_trigger_word(lora_name, prompt):
        """Swap any known '<Style> style,' trigger in *prompt* for *lora_name*'s.

        Every built-in trigger (underscores replaced by spaces, followed by
        " style,") is removed, then the selected one is prepended.
        """
        for name in demo_lora_names:
            trigger = " ".join(name.split("_")) + " style,"
            prompt = prompt.replace(trigger, "")
        new_trigger = " ".join(lora_name.split("_")) + " style,"
        return new_trigger + prompt

    # Each row: lora, custom repo, prompt, image, width, height, guidance,
    # steps, seed — the same order as generate_image's parameters.
    # NOTE(review): the width/height numbers reproduce what the boxes showed
    # under the previous (height, width) input ordering; confirm the intended
    # orientation against the test images.
    examples = [
        ["3D_Chibi", "", "3D Chibi style, Two smiling colleagues high-five at a whiteboard filled with technical notes.",
         Image.open("./test_imgs/00.png"), 1024, 680, 3.5, 24, 42],
        ["Clay_Toy", "", "Clay Toy style, A holiday-themed OpenAI team photo full of smiles and warmth.",
         Image.open("./test_imgs/01.png"), 1024, 560, 3.5, 24, 42],
        ["American_Cartoon", "", "American Cartoon style, A dramatic subtitle moment from a classic film.",
         Image.open("./test_imgs/02.png"), 1024, 568, 3.5, 24, 42],
        ["Origami", "", "Origami style, A Portugal football fan posing with Cristiano Ronaldo.",
         Image.open("./test_imgs/03.png"), 672, 768, 3.5, 24, 42],
        ["Vector", "", "Vector style, The distracted boyfriend meme reimagined.",
         Image.open("./test_imgs/04.png"), 1024, 512, 3.5, 24, 42]
    ]

    # Currently just a newline, so gr.HTML renders nothing visible.
    header = """
"""

    with gr.Blocks() as demo:
        gr.Markdown("# OmniConsistency LoRA Image Generation")
        gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency.")
        gr.HTML(header)
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload Image")
                prompt_box = gr.Textbox(
                    label="Prompt",
                    value="3D Chibi style,",
                    info="Include a style like 'Ghibli style,' in your prompt for better results."
                )
                lora_dropdown = gr.Dropdown(
                    demo_lora_names, label="Select built-in LoRA")
                custom_repo_box = gr.Textbox(
                    label="Enter Custom LoRA",
                    placeholder="e.g. username/repo_name",
                    info="Overrides built-in LoRA if provided."
                )
                gen_btn = gr.Button("Generate")
            with gr.Column(scale=1):
                output_image = gr.ImageSlider(label="Generated Image")
                with gr.Accordion("Advanced Options", open=False):
                    height_box = gr.Textbox(value="1024", label="Height")
                    width_box = gr.Textbox(value="1024", label="Width")
                    guidance_slider = gr.Slider(0.1, 20, value=3.5, step=0.1, label="Guidance Scale")
                    steps_slider = gr.Slider(1, 50, value=25, step=1, label="Inference Steps")
                    seed_slider = gr.Slider(1, 2_147_483_647, value=42, step=1, label="Seed")

        lora_dropdown.select(fn=update_trigger_word, inputs=[lora_dropdown, prompt_box],
                             outputs=prompt_box)

        # One shared list keeps Examples and the Generate button in the same
        # order as generate_image's parameters (width before height).
        generate_inputs = [
            lora_dropdown, custom_repo_box, prompt_box, image_input,
            width_box, height_box, guidance_slider, steps_slider, seed_slider,
        ]

        gr.Examples(
            examples=examples,
            inputs=generate_inputs,
            outputs=output_image,
            fn=generate_image,
            cache_examples=False,
            label="Examples"
        )

        gen_btn.click(
            fn=generate_image,
            inputs=generate_inputs,
            outputs=output_image
        )
    return demo


# ------------------ Run ------------------ #
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(ssr_mode=False)