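# Copied from the Hugging Face Space file view (page chrome kept as a comment so the module parses):
# last commit "Update app.py" (32c2062, verified) by seawolf2357; file size 10.1 kB.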
import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from PIL import Image
import os
import numpy as np
# Style name -> LoRA weights filename, covering every style this demo supports,
# in the order they are offered in the style picker. Every filename follows the
# "<Style>_lora_weights.safetensors" pattern, so it is derived from the name.
_STYLE_NAMES = (
    "3D_Chibi",
    "American_Cartoon",
    "Chinese_Ink",
    "Clay_Toy",
    "Fabric",
    "Ghibli",
    "Irasutoya",
    "Jojo",
    "Oil_Painting",
    "Pixel",
    "Snoopy",
    "Poly",
    "LEGO",
    "Origami",
    "Pop_Art",
    "Van_Gogh",
    "Paper_Cutting",
    "Line",
    "Vector",
    "Picasso",
    "Macaron",
    "Rick_Morty",
)
style_type_lora_dict = {
    name: f"{name}_lora_weights.safetensors" for name in _STYLE_NAMES
}

# Local directory for downloaded LoRA weights; exist_ok makes this a no-op
# when the directory is already present (e.g. after a restart).
os.makedirs("./LoRAs", exist_ok=True)
# Download LoRA weights on demand
def download_lora(style_name, lora_dir="./LoRAs"):
    """Return the local path of the LoRA weights for *style_name*, downloading if needed.

    The weights filename is looked up in ``style_type_lora_dict``. The first time a
    style is requested, the file is fetched from the ``Owen777/Kontext-Style-Loras``
    Hub repository into *lora_dir*. Later calls find the file on disk and skip the
    download.

    Args:
        style_name: A key of ``style_type_lora_dict``.
        lora_dir: Directory holding the downloaded weights (defaults to ``./LoRAs``).

    Returns:
        Path of the ``.safetensors`` file inside *lora_dir*.

    Raises:
        KeyError: If *style_name* is not a known style.
        Exception: Whatever ``hf_hub_download`` raises when the download fails;
            the error is printed and then re-raised unchanged.
    """
    lora_file = style_type_lora_dict[style_name]
    lora_path = os.path.join(lora_dir, lora_file)
    if os.path.exists(lora_path):
        return lora_path

    gr.Info(f"Downloading {style_name} LoRA...")
    try:
        hf_hub_download(
            repo_id="Owen777/Kontext-Style-Loras",
            filename=lora_file,
            local_dir=lora_dir,
        )
    except Exception as e:
        print(f"Error downloading {lora_file}: {e}")
        # Bare ``raise`` re-raises the active exception with its original traceback.
        raise
    print(f"Downloaded {lora_file}")
    return lora_path
# Shared FLUX.1-Kontext pipeline; stays None until the first call to load_pipeline().
pipeline = None


def load_pipeline():
    """Return the shared ``FluxKontextPipeline``, creating it on first use.

    The model is loaded in bfloat16 and cached in the module-level ``pipeline``
    global, so subsequent calls in the same process return the same object.
    The pipeline is returned as loaded; moving it to a device is the caller's job.
    """
    global pipeline
    if pipeline is not None:
        return pipeline

    gr.Info("Loading FLUX.1-Kontext model...")
    pipeline = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        torch_dtype=torch.bfloat16,
    )
    return pipeline
def _build_prompt(style_name, prompt_suffix):
    """Return the Kontext editing instruction for *style_name*.

    Underscores in the style key become spaces ("Oil_Painting" -> "Oil Painting").
    A non-blank *prompt_suffix* (surrounding whitespace stripped) is appended
    after a single space.
    """
    readable = style_name.replace("_", " ")
    prompt = f"Turn this image into the {readable} style."
    suffix = (prompt_suffix or "").strip()
    if suffix:
        prompt += f" {suffix}"
    return prompt


@spaces.GPU(duration=120)  # Request GPU for 120 seconds
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, seed):
    """Restyle *input_image* with the LoRA for *style_name* using FLUX.1-Kontext.

    Args:
        input_image: A PIL image, or a string handed to diffusers' ``load_image``.
            ``None`` shows a warning and returns ``None``.
        style_name: A key of ``style_type_lora_dict``.
        prompt_suffix: Optional extra instruction appended to the generated prompt.
        num_inference_steps: Number of denoising steps (coerced to ``int``).
        seed: A positive value seeds a CUDA generator for reproducible output;
            0, a negative value, or ``None`` gives a random result.

    Returns:
        The first image produced by the pipeline, or ``None`` if no image was given.

    Raises:
        gr.Error: If anything fails while loading weights or generating, so the
            message is displayed in the Gradio UI instead of being swallowed.
    """
    if input_image is None:
        gr.Warning("Please upload an image first!")
        return None

    try:
        pipe = load_pipeline().to("cuda")

        # Gradio number widgets may deliver floats or None; manual_seed needs an int.
        seed_value = int(seed) if seed is not None else 0
        generator = (
            torch.Generator(device="cuda").manual_seed(seed_value)
            if seed_value > 0
            else None
        )

        image = load_image(input_image) if isinstance(input_image, str) else input_image
        # Normalise to 3-channel RGB, then force a 1024x1024 square
        # (non-square inputs are stretched, matching the UI tip).
        image = image.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)

        gr.Info(f"Loading {style_name} style...")
        lora_path = download_lora(style_name)
        # Remove any adapter left from an earlier request: the shared pipeline may be
        # reused, and loading a second adapter under the same name "style" is rejected.
        pipe.unload_lora_weights()
        pipe.load_lora_weights(lora_path, adapter_name="style")
        pipe.set_adapters(["style"], adapter_weights=[1])

        prompt = _build_prompt(style_name, prompt_suffix)

        gr.Info("Generating styled image...")
        result = pipe(
            image=image,
            prompt=prompt,
            height=1024,
            width=1024,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
        )
        return result.images[0]
    except Exception as e:
        # gr.Error only reaches the user when raised; merely constructing it does nothing.
        raise gr.Error(f"Error during style transfer: {str(e)}") from e
    finally:
        # Release cached GPU memory on both the success and the failure path.
        torch.cuda.empty_cache()
# One-line, human-readable description per style key, shown in the UI's
# "Style Description" box. Keys mirror those of style_type_lora_dict.
_STYLE_DESCRIPTION_ROWS = (
    ("3D_Chibi", "Cute, miniature 3D character style with big heads"),
    ("American_Cartoon", "Classic American animation style"),
    ("Chinese_Ink", "Traditional Chinese ink painting aesthetic"),
    ("Clay_Toy", "Playful clay/plasticine toy appearance"),
    ("Fabric", "Soft, textile-like rendering"),
    ("Ghibli", "Studio Ghibli's distinctive anime style"),
    ("Irasutoya", "Simple, flat Japanese illustration style"),
    ("Jojo", "JoJo's Bizarre Adventure manga style"),
    ("Oil_Painting", "Classic oil painting texture and strokes"),
    ("Pixel", "Retro pixel art style"),
    ("Snoopy", "Peanuts comic strip style"),
    ("Poly", "Low-poly 3D geometric style"),
    ("LEGO", "LEGO brick construction style"),
    ("Origami", "Paper folding art style"),
    ("Pop_Art", "Bold, colorful pop art style"),
    ("Van_Gogh", "Van Gogh's expressive brushstroke style"),
    ("Paper_Cutting", "Paper cut-out art style"),
    ("Line", "Clean line art/sketch style"),
    ("Vector", "Clean vector graphics style"),
    ("Picasso", "Cubist art style inspired by Picasso"),
    ("Macaron", "Soft, pastel macaron-like style"),
    ("Rick_Morty", "Rick and Morty cartoon style"),
)
style_descriptions = dict(_STYLE_DESCRIPTION_ROWS)
# Defaults shared by the UI controls and the example runner, so they cannot drift apart.
DEFAULT_STYLE = "Ghibli"
DEFAULT_STEPS = 24
DEFAULT_SEED = 42
NUM_STYLES = len(style_type_lora_dict)
_EXAMPLE_IMAGE_URL = (
    "https://huggingface.co/datasets/black-forest-labs/kontext-bench/"
    "resolve/main/test/images/0003.jpg"
)

# Create Gradio interface
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
# 🎨 FLUX.1 Kontext Style Transfer
Transform your images into various artistic styles using FLUX.1-Kontext-dev and high-quality style LoRAs.
This demo uses the official Owen777/Kontext-Style-Loras collection with {NUM_STYLES} different artistic styles!
""")

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                height=400,
            )
            style_dropdown = gr.Dropdown(
                choices=list(style_type_lora_dict.keys()),
                value=DEFAULT_STYLE,
                label="Select Style",
                info=f"Choose from {NUM_STYLES} different artistic styles",
            )
            style_info = gr.Textbox(
                label="Style Description",
                value=style_descriptions[DEFAULT_STYLE],
                interactive=False,
                lines=2,
            )
            prompt_suffix = gr.Textbox(
                label="Additional Prompt (Optional)",
                placeholder="Add extra details to the transformation...",
                lines=2,
            )
            with gr.Accordion("Advanced Settings", open=False):
                num_steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=DEFAULT_STEPS,
                    step=1,
                    label="Inference Steps",
                    info="More steps = better quality but slower",
                )
                # precision=0 makes Gradio deliver an int, which torch's manual_seed requires.
                seed = gr.Number(
                    label="Seed",
                    value=DEFAULT_SEED,
                    precision=0,
                    info="Set to 0 for random results",
                )
            generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg")

        # Right column: result and usage tips.
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Styled Result",
                type="pil",
                height=400,
            )
            gr.Markdown("""
### 💡 Tips:
- All images are resized to 1024x1024
- First run may take longer to download the model
- Each style LoRA is ~359MB and downloaded on first use
- Try different styles to find the best match!
""")

    # Update style description when style changes
    def update_description(style):
        """Return the description for *style*, or an empty string if it is unknown."""
        return style_descriptions.get(style, "")

    style_dropdown.change(
        fn=update_description,
        inputs=[style_dropdown],
        outputs=[style_info],
    )

    # Examples run with the same defaults as the Advanced Settings controls.
    gr.Examples(
        examples=[
            [_EXAMPLE_IMAGE_URL, "Ghibli", ""],
            [_EXAMPLE_IMAGE_URL, "3D_Chibi", "make it extra cute"],
            [_EXAMPLE_IMAGE_URL, "Van_Gogh", "with swirling sky"],
            [_EXAMPLE_IMAGE_URL, "Pixel", "8-bit retro game style"],
        ],
        inputs=[input_image, style_dropdown, prompt_suffix],
        outputs=output_image,
        fn=lambda img, style, prompt: style_transfer(
            img, style, prompt, DEFAULT_STEPS, DEFAULT_SEED
        ),
        cache_examples=False,
    )

    # Connect the generate button
    generate_btn.click(
        fn=style_transfer,
        inputs=[input_image, style_dropdown, prompt_suffix, num_steps, seed],
        outputs=output_image,
    )

    # Blank lines keep each category its own paragraph. Without the blank line
    # before "---", Markdown would turn the preceding text into a setext heading.
    gr.Markdown("""
---
### 📚 Available Styles:

**Anime/Cartoon**: Ghibli, American Cartoon, Jojo, Snoopy, Rick & Morty, Irasutoya

**3D/Geometric**: 3D Chibi, Poly, LEGO, Clay Toy

**Traditional Art**: Chinese Ink, Oil Painting, Van Gogh, Picasso, Pop Art

**Craft/Material**: Fabric, Origami, Paper Cutting, Macaron

**Digital/Modern**: Pixel, Line, Vector

---

Created with ❤️ using [Owen777/Kontext-Style-Loras](https://huggingface.co/Owen777/Kontext-Style-Loras)
""")

if __name__ == "__main__":
    demo.launch()