import os

import gradio as gr
import spaces
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from PIL import Image

# Style dictionary: style key -> LoRA weight file inside the
# "Owen777/Kontext-Style-Loras" repository. Key order also defines the order of
# the thumbnail gallery, so gallery indices map back to these keys.
style_type_lora_dict = {
    "3D_Chibi": "3D_Chibi_lora_weights.safetensors",
    "American_Cartoon": "American_Cartoon_lora_weights.safetensors",
    "Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
    "Clay_Toy": "Clay_Toy_lora_weights.safetensors",
    "Fabric": "Fabric_lora_weights.safetensors",
    "Ghibli": "Ghibli_lora_weights.safetensors",
    "Irasutoya": "Irasutoya_lora_weights.safetensors",
    "Jojo": "Jojo_lora_weights.safetensors",
    "Oil_Painting": "Oil_Painting_lora_weights.safetensors",
    "Pixel": "Pixel_lora_weights.safetensors",
    "Snoopy": "Snoopy_lora_weights.safetensors",
    "Poly": "Poly_lora_weights.safetensors",
    "LEGO": "LEGO_lora_weights.safetensors",
    "Origami": "Origami_lora_weights.safetensors",
    "Pop_Art": "Pop_Art_lora_weights.safetensors",
    "Van_Gogh": "Van_Gogh_lora_weights.safetensors",
    "Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
    "Line": "Line_lora_weights.safetensors",
    "Vector": "Vector_lora_weights.safetensors",
    "Picasso": "Picasso_lora_weights.safetensors",
    "Macaron": "Macaron_lora_weights.safetensors",
    "Rick_Morty": "Rick_Morty_lora_weights.safetensors",
}

# Style descriptions shown in the "Style Description" textbox.
style_descriptions = {
    "3D_Chibi": "Cute, miniature 3D character style with big heads",
    "American_Cartoon": "Classic American animation style",
    "Chinese_Ink": "Traditional Chinese ink painting aesthetic",
    "Clay_Toy": "Playful clay/plasticine toy appearance",
    "Fabric": "Soft, textile-like rendering",
    "Ghibli": "Studio Ghibli's distinctive anime style",
    "Irasutoya": "Simple, flat Japanese illustration style",
    "Jojo": "JoJo's Bizarre Adventure manga style",
    "Oil_Painting": "Classic oil painting texture and strokes",
    "Pixel": "Retro pixel art style",
    "Snoopy": "Peanuts comic strip style",
    "Poly": "Low-poly 3D geometric style",
    "LEGO": "LEGO brick construction style",
    "Origami": "Paper folding art style",
    "Pop_Art": "Bold, colorful pop art style",
    "Van_Gogh": "Van Gogh's expressive brushstroke style",
    "Paper_Cutting": "Paper cut-out art style",
    "Line": "Clean line art/sketch style",
    "Vector": "Clean vector graphics style",
    "Picasso": "Cubist art style inspired by Picasso",
    "Macaron": "Soft, pastel macaron-like style",
    "Rick_Morty": "Rick and Morty cartoon style",
}

# Mapping for thumbnail files (paths relative to the working directory).
thumbnail_mapping = {
    "3D_Chibi": "3D_Chibi.webp",
    "American_Cartoon": "american_cartoon.webp",
    "Chinese_Ink": "chinese_ink.webp",
    "Clay_Toy": "clay_toy.webp",
    "Fabric": "fabric.webp",
    "Ghibli": "ghibli.webp",
    "Irasutoya": "Irasutoya.webp",
    "Jojo": "jojo.webp",
    "Oil_Painting": "oil_painting.webp",
    "Pixel": "pixel.webp",
    "Snoopy": "snoopy.webp",
    "Poly": "poly.webp",
    "LEGO": "LEGO.webp",
    "Origami": "origami.webp",
    "Pop_Art": "pop-art.webp",
    "Van_Gogh": "van_gogh.webp",
    "Paper_Cutting": "Paper_Cutting.webp",
    "Line": "line.webp",
    "Vector": "vector.webp",
    "Picasso": "picasso.webp",
    "Macaron": "Macaron.webp",
    "Rick_Morty": "Rick_Morty.webp",
}

# Initialize pipeline globally; it is created lazily by load_pipeline() and
# then reused for every request.
pipeline = None
pipeline_loaded = False


def load_pipeline():
    """Return the shared FluxKontextPipeline, loading it on the first call.

    The Hugging Face token is read from the ``HF_TOKEN`` environment variable.
    When it is unset, ``None`` is passed, which lets huggingface_hub fall back
    to its own token discovery.
    """
    global pipeline, pipeline_loaded
    if pipeline is None:
        print("Loading FLUX.1-Kontext-dev model...")
        # Auto-detect HF_TOKEN. `token` is the supported keyword;
        # `use_auth_token` is deprecated in diffusers/huggingface_hub.
        pipeline = FluxKontextPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Kontext-dev",
            torch_dtype=torch.bfloat16,
            token=os.getenv("HF_TOKEN"),
        )
        pipeline_loaded = True
    return pipeline


def load_default_image(path="man.webp"):
    """Load the default man.webp image.

    Args:
        path: Image file to load; defaults to the bundled ``man.webp``.

    Returns:
        A fully loaded PIL image, or None if the file is missing or unreadable.
    """
    if not os.path.exists(path):
        return None
    try:
        # Copy inside `with` so the underlying file handle is closed right away.
        with Image.open(path) as img:
            return img.copy()
    except Exception as e:
        print(f"Error loading default image: {e}")
        return None


@spaces.GPU(duration=120)
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed):
    """Apply style transfer to the input image using the selected style LoRA.

    Args:
        input_image: PIL image, or a path/URL string accepted by
            ``diffusers.utils.load_image``.
        style_name: Key of ``style_type_lora_dict``.
        prompt_suffix: Optional extra instructions appended to the prompt.
        num_inference_steps: Number of denoising steps (cast to int).
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Values > 0 give a reproducible result; 0, negative values or
            None leave the generator unseeded.

    Returns:
        The styled PIL image (1024x1024), or None when the input is missing.

    Raises:
        gr.Error: If loading weights or generation fails; Gradio shows the
            message to the user.
    """
    if input_image is None:
        gr.Warning("Please upload an image first!")
        return None
    if style_name not in style_type_lora_dict:
        gr.Warning("Please select a style first!")
        return None

    try:
        # Load the pipeline and keep it entirely on the GPU. Do not combine this
        # with enable_model_cpu_offload(): that moves the components back to the
        # CPU and (re)installs offload hooks, defeating the explicit .to("cuda").
        pipe = load_pipeline()
        pipe = pipe.to("cuda")

        # Set seed for reproducibility (Number fields may arrive as None/float).
        generator = None
        if seed is not None and int(seed) > 0:
            generator = torch.Generator(device="cuda").manual_seed(int(seed))

        # Process input image
        image = load_image(input_image) if isinstance(input_image, str) else input_image
        # Ensure RGB and resize to 1024x1024
        image = image.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)

        # Clear any previously loaded LoRA so only the selected style is active.
        try:
            pipe.unload_lora_weights()
        except Exception as e:
            # Best effort: a failure here should not block loading the new style.
            print(f"Could not unload previous LoRA weights: {e}")

        # Load the selected LoRA weights
        pipe.load_lora_weights(
            "Owen777/Kontext-Style-Loras",
            weight_name=style_type_lora_dict[style_name],
            adapter_name="style",
        )
        pipe.set_adapters(["style"], adapter_weights=[1.0])

        # Create prompt for style transformation
        style_name_readable = style_name.replace("_", " ")
        prompt = f"Turn this image into the {style_name_readable} style."
        if prompt_suffix and prompt_suffix.strip():
            prompt += f" {prompt_suffix.strip()}"

        print(f"Generating with prompt: {prompt}")

        # Generate the styled image
        result = pipe(
            image=image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            height=1024,
            width=1024,
        )
        return result.images[0]
    except Exception as e:
        print(f"Error: {str(e)}")
        # gr.Error is an exception class; it only reaches the UI when raised.
        raise gr.Error(f"Error during style transfer: {str(e)}") from e
    finally:
        # Clear GPU memory on both the success and the failure path.
        torch.cuda.empty_cache()


def _placeholder_thumbnail():
    """Return a plain light-gray 256x256 tile used when a thumbnail is unavailable."""
    return Image.new("RGB", (256, 256), color="lightgray")


def create_thumbnail_grid():
    """Create a gallery of style thumbnails.

    Returns:
        A list of ``(image, caption)`` tuples in ``style_type_lora_dict`` order,
        so a gallery index maps directly back to a style key. Missing or
        unreadable thumbnail files are replaced with a gray placeholder.
    """
    thumbnails = []
    for style in style_type_lora_dict:
        caption = style.replace("_", " ")
        thumbnail_file = thumbnail_mapping.get(style, "")
        img = None
        if thumbnail_file and os.path.exists(thumbnail_file):
            try:
                # Copy inside `with` so the file handle is not left open.
                with Image.open(thumbnail_file) as opened:
                    img = opened.copy()
            except Exception as e:
                print(f"Error loading thumbnail {thumbnail_file}: {e}")
        thumbnails.append((img if img is not None else _placeholder_thumbnail(), caption))
    return thumbnails


# Create Gradio interface
with gr.Blocks(title="Flux Kontext Style LoRA", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Flux Styler : Flux Kontext Style LoRA")

    # Thumbnail Grid Section
    with gr.Row():
        # Static (non-interactive) gallery: it is only a style picker. An
        # interactive gallery would let users upload extra images and break the
        # index -> style mapping used by on_gallery_select.
        style_gallery = gr.Gallery(
            value=create_thumbnail_grid(),
            label="Style Thumbnails",
            show_label=False,
            elem_id="style_gallery",
            columns=6,
            rows=4,
            object_fit="cover",
            height="auto",
            interactive=False,
            show_download_button=False,
        )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400,
                value=load_default_image(),
            )

            style_dropdown = gr.Dropdown(
                choices=list(style_type_lora_dict.keys()),
                value="Ghibli",
                label="Selected Style",
                elem_id="style_dropdown",
            )

            style_info = gr.Textbox(
                label="Style Description",
                value=style_descriptions["Ghibli"],
                interactive=False,
                lines=2,
            )

            prompt_suffix = gr.Textbox(
                label="Additional Instructions (Optional)",
                placeholder="Add extra details...",
                lines=2,
            )

            with gr.Accordion("Advanced Settings", open=False):
                num_steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=24,
                    step=1,
                    label="Inference Steps",
                )

                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    value=2.5,
                    step=0.1,
                    label="Guidance Scale",
                )

                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                )

            generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg")

        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Styled Result",
                type="pil",
                height=400,
            )

    # Handle gallery selection
    def on_gallery_select(evt: gr.SelectData):
        """Handle thumbnail selection from gallery.

        Returns the selected style key and its description. For an
        out-of-range index both outputs are left unchanged instead of being
        cleared.
        """
        styles = list(style_type_lora_dict.keys())
        selected_index = evt.index
        if isinstance(selected_index, int) and 0 <= selected_index < len(styles):
            selected_style = styles[selected_index]
            return selected_style, style_descriptions.get(selected_style, "")
        return gr.update(), gr.update()

    style_gallery.select(
        fn=on_gallery_select,
        inputs=None,
        outputs=[style_dropdown, style_info],
    )

    # Update style description when style changes
    def update_description(style):
        """Return the description for ``style`` ("" for unknown styles)."""
        return style_descriptions.get(style, "")

    style_dropdown.change(
        fn=update_description,
        inputs=[style_dropdown],
        outputs=[style_info],
    )

    # Connect the generate button
    generate_btn.click(
        fn=style_transfer,
        inputs=[input_image, style_dropdown, prompt_suffix, num_steps, guidance, seed],
        outputs=output_image,
    )

    gr.Markdown("""
---
Powered by ❤️ https://discord.gg/openfreeai
""")

if __name__ == "__main__":
    demo.launch()