import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from PIL import Image
import os
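
# Local-run note (an assumption, not part of the original Space): with the imported
# dependencies installed (gradio, spaces, torch, diffusers, huggingface_hub, Pillow)
# the app can be started with `python app.py`; outside a ZeroGPU Space the
# @spaces.GPU decorator is expected to act as a pass-through, and a CUDA GPU with
# enough memory for FLUX.1-Kontext-dev in bfloat16 is assumed.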

# Style dictionary
style_type_lora_dict = {
    "3D_Chibi": "3D_Chibi_lora_weights.safetensors",
    "American_Cartoon": "American_Cartoon_lora_weights.safetensors",
    "Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
    "Clay_Toy": "Clay_Toy_lora_weights.safetensors",
    "Fabric": "Fabric_lora_weights.safetensors",
    "Ghibli": "Ghibli_lora_weights.safetensors",
    "Irasutoya": "Irasutoya_lora_weights.safetensors",
    "Jojo": "Jojo_lora_weights.safetensors",
    "Oil_Painting": "Oil_Painting_lora_weights.safetensors",
    "Pixel": "Pixel_lora_weights.safetensors",
    "Snoopy": "Snoopy_lora_weights.safetensors",
    "Poly": "Poly_lora_weights.safetensors",
    "LEGO": "LEGO_lora_weights.safetensors",
    "Origami": "Origami_lora_weights.safetensors",
    "Pop_Art": "Pop_Art_lora_weights.safetensors",
    "Van_Gogh": "Van_Gogh_lora_weights.safetensors",
    "Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
    "Line": "Line_lora_weights.safetensors",
    "Vector": "Vector_lora_weights.safetensors",
    "Picasso": "Picasso_lora_weights.safetensors",
    "Macaron": "Macaron_lora_weights.safetensors",
    "Rick_Morty": "Rick_Morty_lora_weights.safetensors"
}

# Create LoRAs directory if it doesn't exist
os.makedirs("./LoRAs", exist_ok=True)
# Download all LoRA weights at startup
print("Downloading LoRA weights...")
for style_name, lora_file in style_type_lora_dict.items():
    if not os.path.exists(f"./LoRAs/{lora_file}"):
        hf_hub_download(
            repo_id="Owen777/Kontext-Style-Loras",
            filename=lora_file,
            local_dir="./LoRAs"
        )
print("All LoRA weights downloaded!")

# Initialize pipeline globally (will be loaded to GPU when needed)
pipeline = None

def load_pipeline():
    global pipeline
    if pipeline is None:
        pipeline = FluxKontextPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Kontext-dev",
            torch_dtype=torch.bfloat16
        )
    return pipeline

@spaces.GPU(duration=120) # Request GPU for 120 seconds
def style_transfer(input_image, style_name, num_inference_steps, guidance_scale, seed):
"""
Apply style transfer to the input image using selected style
"""
# Load pipeline and move to GPU
pipe = load_pipeline()
pipe = pipe.to('cuda')
# Set seed for reproducibility
if seed is not None and seed > 0:
generator = torch.Generator(device="cuda").manual_seed(seed)
else:
generator = None
# Resize input image to 1024x1024
if isinstance(input_image, str):
image = load_image(input_image)
else:
image = input_image
image = image.resize((1024, 1024), Image.Resampling.LANCZOS)

    # Load the selected LoRA; unload any previously loaded adapter first so that
    # repeated calls don't try to register the "style_lora" adapter name twice
    pipe.unload_lora_weights()
    lora_path = f"./LoRAs/{style_type_lora_dict[style_name]}"
    pipe.load_lora_weights(lora_path, adapter_name="style_lora")
    pipe.set_adapters(["style_lora"], adapter_weights=[1.0])

    # Generate the styled image
    prompt = f"Turn this image into the {style_name.replace('_', ' ')} style."
    result = pipe(
        image=image,
        prompt=prompt,
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    )

    # Clear GPU memory
    torch.cuda.empty_cache()

    return result.images[0]

# Create Gradio interface
with gr.Blocks(title="Flux Kontext Style Transfer") as demo:
gr.Markdown("""
# 🎨 Flux Kontext Style Transfer
Transform your images into various artistic styles using FLUX.1-Kontext and style-specific LoRAs.
Upload an image and select a style to apply the transformation!
""")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400
            )

            style_dropdown = gr.Dropdown(
                choices=list(style_type_lora_dict.keys()),
                value="3D_Chibi",
                label="Select Style",
                info="Choose the artistic style to apply"
            )

            with gr.Accordion("Advanced Settings", open=False):
                num_steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=24,
                    step=1,
                    label="Number of Inference Steps",
                    info="More steps = better quality but slower"
                )

                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.5,
                    step=0.5,
                    label="Guidance Scale",
                    info="Higher values = stronger style adherence"
                )

                seed = gr.Number(
                    label="Seed",
                    value=0,
                    precision=0,
                    info="Set to 0 for a random seed, or use a specific seed for reproducibility"
                )

            generate_btn = gr.Button("🎨 Apply Style Transfer", variant="primary")

        with gr.Column():
            output_image = gr.Image(
                label="Styled Output",
                type="pil",
                height=400
            )

    # Examples
    gr.Examples(
        examples=[
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi"],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli"],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh"],
        ],
        inputs=[input_image, style_dropdown],
        outputs=output_image,
        fn=lambda img, style: style_transfer(img, style, 24, 3.5, 0),
        cache_examples=True
    )

    # Connect the generate button
    generate_btn.click(
        fn=style_transfer,
        inputs=[input_image, style_dropdown, num_steps, guidance, seed],
        outputs=output_image
    )
gr.Markdown("""
## πŸ“ Notes:
- Processing takes about 30-60 seconds depending on the number of steps
- All images are resized to 1024x1024 for optimal results
- Different styles work better with different types of images
- Try adjusting the advanced settings for better results
## 🎨 Available Styles:
3D Chibi, American Cartoon, Chinese Ink, Clay Toy, Fabric, Ghibli, Irasutoya,
Jojo, Oil Painting, Pixel, Snoopy, Poly, LEGO, Origami, Pop Art, Van Gogh,
Paper Cutting, Line, Vector, Picasso, Macaron, Rick & Morty
""")

if __name__ == "__main__":
    demo.launch()