ginipick's picture
Update app.py
286f29e verified
raw
history blame
13.3 kB
import gradio as gr
import spaces
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from PIL import Image
import os
# Style dictionary: style name -> LoRA weight file in the
# "Owen777/Kontext-Style-Loras" repo. Every file follows the
# "<style>_lora_weights.safetensors" pattern, so the mapping is generated from
# the ordered style names. Order matters: it drives the dropdown choices and
# the thumbnail gallery order (and therefore gallery index -> style lookup).
style_type_lora_dict = {
    style: f"{style}_lora_weights.safetensors"
    for style in (
        "3D_Chibi",
        "American_Cartoon",
        "Chinese_Ink",
        "Clay_Toy",
        "Fabric",
        "Ghibli",
        "Irasutoya",
        "Jojo",
        "Oil_Painting",
        "Pixel",
        "Snoopy",
        "Poly",
        "LEGO",
        "Origami",
        "Pop_Art",
        "Van_Gogh",
        "Paper_Cutting",
        "Line",
        "Vector",
        "Picasso",
        "Macaron",
        "Rick_Morty",
    )
}
# Style descriptions: one short user-facing blurb per style, displayed in the
# "Style Description" textbox. Keys mirror style_type_lora_dict.
style_descriptions = dict(
    [
        ("3D_Chibi", "Cute, miniature 3D character style with big heads"),
        ("American_Cartoon", "Classic American animation style"),
        ("Chinese_Ink", "Traditional Chinese ink painting aesthetic"),
        ("Clay_Toy", "Playful clay/plasticine toy appearance"),
        ("Fabric", "Soft, textile-like rendering"),
        ("Ghibli", "Studio Ghibli's distinctive anime style"),
        ("Irasutoya", "Simple, flat Japanese illustration style"),
        ("Jojo", "JoJo's Bizarre Adventure manga style"),
        ("Oil_Painting", "Classic oil painting texture and strokes"),
        ("Pixel", "Retro pixel art style"),
        ("Snoopy", "Peanuts comic strip style"),
        ("Poly", "Low-poly 3D geometric style"),
        ("LEGO", "LEGO brick construction style"),
        ("Origami", "Paper folding art style"),
        ("Pop_Art", "Bold, colorful pop art style"),
        ("Van_Gogh", "Van Gogh's expressive brushstroke style"),
        ("Paper_Cutting", "Paper cut-out art style"),
        ("Line", "Clean line art/sketch style"),
        ("Vector", "Clean vector graphics style"),
        ("Picasso", "Cubist art style inspired by Picasso"),
        ("Macaron", "Soft, pastel macaron-like style"),
        ("Rick_Morty", "Rick and Morty cartoon style"),
    ]
)
# Mapping for thumbnail files: style name -> thumbnail image file name. The
# bare names are checked with os.path.exists when the gallery is built, i.e.
# relative to the current working directory. Casing is not uniform across
# files, so every name is spelled out explicitly.
thumbnail_mapping = {
    style: filename
    for style, filename in (
        ("3D_Chibi", "3D_Chibi.webp"),
        ("American_Cartoon", "american_cartoon.webp"),
        ("Chinese_Ink", "chinese_ink.webp"),
        ("Clay_Toy", "clay_toy.webp"),
        ("Fabric", "fabric.webp"),
        ("Ghibli", "ghibli.webp"),
        ("Irasutoya", "Irasutoya.webp"),
        ("Jojo", "jojo.webp"),
        ("Oil_Painting", "oil_painting.webp"),
        ("Pixel", "pixel.webp"),
        ("Snoopy", "snoopy.webp"),
        ("Poly", "poly.webp"),
        ("LEGO", "LEGO.webp"),
        ("Origami", "origami.webp"),
        ("Pop_Art", "pop-art.webp"),
        ("Van_Gogh", "van_gogh.webp"),
        ("Paper_Cutting", "Paper_Cutting.webp"),
        ("Line", "line.webp"),
        ("Vector", "vector.webp"),
        ("Picasso", "picasso.webp"),
        ("Macaron", "Macaron.webp"),
        ("Rick_Morty", "Rick_Morty.webp"),
    )
}
# Initialize pipeline globally: a single lazily-created pipeline, reused by
# every call to load_pipeline().
pipeline = None
# Flipped to True once the pipeline has been loaded.
pipeline_loaded = False


def load_pipeline():
    """Return the shared FLUX.1-Kontext-dev pipeline, loading it on first use.

    The pipeline is created once (bfloat16 weights) and cached in the
    module-level ``pipeline`` global; later calls return the same object.

    The Hugging Face token comes from the ``HF_TOKEN`` environment variable.
    When it is unset, ``token=None`` is passed so huggingface_hub performs its
    own credential lookup (e.g. a token saved by ``huggingface-cli login``);
    passing ``True`` instead would require a locally stored token.

    Returns:
        The loaded ``FluxKontextPipeline``.
    """
    global pipeline, pipeline_loaded
    if pipeline is None:
        print("Loading FLUX.1-Kontext-dev model...")
        # Auto-detect HF_TOKEN.
        token = os.getenv("HF_TOKEN")
        pipeline = FluxKontextPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Kontext-dev",
            torch_dtype=torch.bfloat16,
            # `token` is the current name of the deprecated `use_auth_token`.
            token=token,
        )
        pipeline_loaded = True
    return pipeline
@spaces.GPU(duration=120)
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed):
    """
    Apply style transfer to the input image using the selected style LoRA.

    Args:
        input_image: PIL image, or a path/URL string (loaded via ``load_image``).
        style_name: Key of ``style_type_lora_dict`` selecting the LoRA file.
        prompt_suffix: Optional extra instructions appended to the prompt.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for a CUDA generator; ``None`` or a value <= 0 leaves the
            generator unset, so results are random.

    Returns:
        The first generated image (1024x1024), or ``None`` when no image was
        uploaded or the style name is unknown.

    Raises:
        gr.Error: when loading the model/LoRA or generation fails, so Gradio
            shows the message to the user.
    """
    if input_image is None:
        gr.Warning("Please upload an image first!")
        return None
    if style_name not in style_type_lora_dict:
        gr.Warning("Please select a style first!")
        return None

    try:
        # Load the shared pipeline and move it to the GPU granted by
        # @spaces.GPU. Model CPU offload is intentionally not enabled:
        # enable_model_cpu_offload() moves the pipeline back to the CPU,
        # which would undo this transfer on every request.
        pipe = load_pipeline().to("cuda")

        # Seed for reproducibility; an empty or non-positive seed means random.
        generator = None
        if seed is not None and seed > 0:
            generator = torch.Generator(device="cuda").manual_seed(int(seed))

        # Process input image (PIL image or path/URL string).
        if isinstance(input_image, str):
            image = load_image(input_image)
        else:
            image = input_image
        # Ensure RGB and resize to a fixed 1024x1024 (aspect ratio is not kept).
        image = image.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)

        lora_filename = style_type_lora_dict[style_name]

        # Clear the LoRA left over from the previous request. Best effort:
        # nothing is loaded yet on the very first request.
        try:
            pipe.unload_lora_weights()
        except Exception as unload_error:
            print(f"Could not unload previous LoRA weights: {unload_error}")

        # Load the selected LoRA and activate it at full strength.
        pipe.load_lora_weights(
            "Owen777/Kontext-Style-Loras",
            weight_name=lora_filename,
            adapter_name="style",
        )
        pipe.set_adapters(["style"], adapter_weights=[1.0])

        # Create prompt for style transformation.
        style_name_readable = style_name.replace('_', ' ')
        prompt = f"Turn this image into the {style_name_readable} style."
        if prompt_suffix and prompt_suffix.strip():
            prompt += f" {prompt_suffix.strip()}"
        print(f"Generating with prompt: {prompt}")

        result = pipe(
            image=image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            height=1024,
            width=1024,
        )
        return result.images[0]
    except Exception as e:
        print(f"Error: {str(e)}")
        # gr.Error must be raised (not just constructed) for Gradio to
        # display it to the user.
        raise gr.Error(f"Error during style transfer: {str(e)}") from e
    finally:
        # Clear GPU memory on both the success and the error path.
        torch.cuda.empty_cache()
def select_style(style_name):
    """Return ``(style_name, description)`` for a clicked style thumbnail.

    The description comes from ``style_descriptions``; a name without an
    entry yields an empty string.
    """
    description = style_descriptions.get(style_name, "")
    return style_name, description
def create_thumbnail_grid():
    """Create the ``(image, caption)`` list for the style thumbnail gallery.

    Exactly one entry is produced per key of ``style_type_lora_dict``, in that
    order, so gallery indices line up with the style list used by the
    gallery's select handler. A missing or unreadable thumbnail file is
    replaced by a 256x256 light-gray placeholder, which keeps that alignment.

    Returns:
        list of ``(PIL.Image.Image, str)`` tuples; each caption is the style
        name with underscores replaced by spaces.
    """
    thumbnails = []
    for style in style_type_lora_dict:
        caption = style.replace('_', ' ')
        thumbnail_file = thumbnail_mapping.get(style, "")
        img = None
        if thumbnail_file and os.path.exists(thumbnail_file):
            try:
                # Image.open is lazy and keeps the file open; copy the pixels
                # out inside a context manager so the handle is closed now.
                with Image.open(thumbnail_file) as src:
                    img = src.copy()
            except Exception as e:
                print(f"Error loading thumbnail {thumbnail_file}: {e}")
        if img is None:
            # Placeholder for missing or unreadable thumbnails.
            img = Image.new('RGB', (256, 256), color='lightgray')
        thumbnails.append((img, caption))
    return thumbnails
# Create Gradio interface
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
    # Header. The style count is computed from style_type_lora_dict so the
    # text cannot drift from the actual list of styles.
    gr.Markdown(f"""
# 🎨 FLUX.1 Kontext Style Transfer
Transform your images into various artistic styles using FLUX.1-Kontext-dev and high-quality style LoRAs.
This demo uses the official Owen777/Kontext-Style-Loras collection with {len(style_type_lora_dict)} different artistic styles!
""")

    # Thumbnail Grid Section
    gr.Markdown("### 🖼️ Click a style thumbnail to select it:")
    with gr.Row():
        # create_thumbnail_grid() yields one item per style in
        # style_type_lora_dict order; on_gallery_select relies on that order.
        # NOTE(review): interactive=True may also let users upload/remove
        # gallery items, which would desync indices from styles — confirm.
        style_gallery = gr.Gallery(
            value=create_thumbnail_grid(),
            label="Style Thumbnails",
            show_label=False,
            elem_id="style_gallery",
            columns=6,
            rows=4,
            object_fit="cover",
            height="auto",
            interactive=True,
            show_download_button=False
        )

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                height=400
            )
            style_dropdown = gr.Dropdown(
                choices=list(style_type_lora_dict.keys()),
                value="Ghibli",
                label="Selected Style",
                info=f"Choose from {len(style_type_lora_dict)} different artistic styles or click a thumbnail above",
                elem_id="style_dropdown"
            )
            style_info = gr.Textbox(
                label="Style Description",
                value=style_descriptions["Ghibli"],
                interactive=False,
                lines=2
            )
            prompt_suffix = gr.Textbox(
                label="Additional Instructions (Optional)",
                placeholder="Add extra details like 'make it more colorful' or 'add dramatic lighting'...",
                lines=2
            )
            with gr.Accordion("Advanced Settings", open=False):
                num_steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=24,
                    step=1,
                    label="Inference Steps",
                    info="More steps = better quality but slower"
                )
                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    value=2.5,
                    step=0.1,
                    label="Guidance Scale",
                    info="How closely to follow the prompt (2.5 recommended)"
                )
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                    info="Set to 0 for random results"
                )
            generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg")

        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Styled Result",
                type="pil",
                height=400
            )
            gr.Markdown("""
### 💡 Tips:
- Click any thumbnail above to quickly select a style
- All images are resized to 1024x1024
- First run downloads the model (~12GB)
- Each style transformation takes ~30-60 seconds
- Try different styles to find the best match!
- Use additional instructions for fine control
""")

    # Handle gallery selection
    def on_gallery_select(evt: gr.SelectData):
        """Map the clicked thumbnail index to (style name, description).

        Thumbnails are built in style_type_lora_dict order, so the index
        identifies the style directly. An unusable index leaves both output
        components unchanged instead of clearing the selected style.
        """
        selected_index = evt.index
        styles = list(style_type_lora_dict.keys())
        if isinstance(selected_index, int) and 0 <= selected_index < len(styles):
            selected_style = styles[selected_index]
            return selected_style, style_descriptions.get(selected_style, "")
        # gr.update() with no arguments means "no change".
        return gr.update(), gr.update()

    style_gallery.select(
        fn=on_gallery_select,
        inputs=None,
        outputs=[style_dropdown, style_info]
    )

    # Update style description when style changes
    def update_description(style):
        """Return the description for ``style`` ("" if unknown)."""
        return style_descriptions.get(style, "")

    style_dropdown.change(
        fn=update_description,
        inputs=[style_dropdown],
        outputs=[style_info]
    )

    # Examples. With cache_examples=False (and run_on_click at its default),
    # clicking an example presumably only fills the inputs; the user then
    # presses the button to run style_transfer.
    gr.Examples(
        examples=[
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli", ""],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi", "make it extra cute"],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh", "with swirling sky"],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Pixel", "8-bit retro game style"],
        ],
        inputs=[input_image, style_dropdown, prompt_suffix],
        outputs=output_image,
        fn=style_transfer,
        cache_examples=False
    )

    # Connect the generate button
    generate_btn.click(
        fn=style_transfer,
        inputs=[input_image, style_dropdown, prompt_suffix, num_steps, guidance, seed],
        outputs=output_image
    )

    # Footer. Blank lines keep each category on its own paragraph and stop the
    # closing "---" from turning the last line into a setext heading.
    gr.Markdown("""
---

### 📚 Available Styles:

**Anime/Cartoon**: Ghibli, American Cartoon, Jojo, Snoopy, Rick & Morty, Irasutoya

**3D/Geometric**: 3D Chibi, Poly, LEGO, Clay Toy

**Traditional Art**: Chinese Ink, Oil Painting, Van Gogh, Picasso, Pop Art

**Craft/Material**: Fabric, Origami, Paper Cutting, Macaron

**Digital/Modern**: Pixel, Line, Vector

---

Powered by ❀️ https://discord.gg/openfreeai
""")
def _main():
    """Start the Gradio server for the ``demo`` app defined in this module."""
    demo.launch()


if __name__ == "__main__":
    _main()