Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,033 Bytes
a02323a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from PIL import Image
import os
# Style dictionary: maps each style name (as shown in the dropdown) to the
# LoRA weights file for that style. Every file name follows the same pattern,
# "<style>_lora_weights.safetensors", so the mapping is derived from the
# ordered list of style names below (insertion order is the dropdown order).
style_type_lora_dict = {
    style: f"{style}_lora_weights.safetensors"
    for style in (
        "3D_Chibi",
        "American_Cartoon",
        "Chinese_Ink",
        "Clay_Toy",
        "Fabric",
        "Ghibli",
        "Irasutoya",
        "Jojo",
        "Oil_Painting",
        "Pixel",
        "Snoopy",
        "Poly",
        "LEGO",
        "Origami",
        "Pop_Art",
        "Van_Gogh",
        "Paper_Cutting",
        "Line",
        "Vector",
        "Picasso",
        "Macaron",
        "Rick_Morty",
    )
}
# Create LoRAs directory if it doesn't exist
os.makedirs("./LoRAs", exist_ok=True)

# Download every style's LoRA weights at startup from the
# Owen777/Kontext-Style-Loras repository. Files already present in ./LoRAs
# are skipped, so restarts do not re-download them.
print("Downloading LoRA weights...")
# Only the file names are needed here, so iterate the dict's values.
for lora_file in style_type_lora_dict.values():
    if not os.path.exists(os.path.join("./LoRAs", lora_file)):
        hf_hub_download(
            repo_id="Owen777/Kontext-Style-Loras",
            filename=lora_file,
            local_dir="./LoRAs",
        )
print("All LoRA weights downloaded!")
# Process-wide FluxKontextPipeline instance. It starts as None and is created
# lazily by load_pipeline() on first use.
pipeline = None


def load_pipeline():
    """Return the shared FLUX.1-Kontext-dev pipeline, creating it on first call.

    The pipeline is loaded from "black-forest-labs/FLUX.1-Kontext-dev" with
    ``torch_dtype=torch.bfloat16`` and cached in the module-level ``pipeline``
    global; later calls return the same object. This function does not move
    the pipeline to any device.
    """
    global pipeline
    if pipeline is not None:
        # Already created: reuse the cached instance.
        return pipeline
    pipeline = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        torch_dtype=torch.bfloat16,
    )
    return pipeline
def _prepare_image(input_image):
    """Return *input_image* as a 1024x1024 RGB PIL image.

    Accepts either a PIL image (what ``gr.Image(type="pil")`` delivers) or a
    path/URL string, which is loaded with ``diffusers.utils.load_image``.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_image is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an input image first.")
    if isinstance(input_image, str):
        image = load_image(input_image)
    else:
        image = input_image
    # Uploads may arrive as RGBA or palette images; normalise to 3-channel RGB,
    # which is what load_image already returns for string inputs.
    image = image.convert("RGB")
    # Note: a fixed 1024x1024 resize does not preserve the aspect ratio.
    return image.resize((1024, 1024), Image.Resampling.LANCZOS)


@spaces.GPU(duration=120)  # Request GPU for 120 seconds
def style_transfer(input_image, style_name, num_inference_steps, guidance_scale, seed):
    """Apply the LoRA for *style_name* to *input_image* with FLUX.1 Kontext.

    Args:
        input_image: PIL image, or a path/URL string, of the source image.
        style_name: key of ``style_type_lora_dict`` selecting the LoRA.
        num_inference_steps: number of denoising steps (cast to int).
        guidance_scale: guidance scale passed to the pipeline (cast to float).
        seed: a positive value seeds a CUDA generator for reproducible output;
            0, a negative value, or None leaves the generator unset (random).

    Returns:
        The first generated image from the pipeline result (1024x1024).

    Raises:
        gr.Error: if no image is given or *style_name* is not a known style.
    """
    # Validate cheap inputs before loading / moving the pipeline.
    lora_file = style_type_lora_dict.get(style_name)
    if lora_file is None:
        raise gr.Error(f"Unknown style: {style_name!r}")
    image = _prepare_image(input_image)

    # Load pipeline and move to GPU
    pipe = load_pipeline().to("cuda")

    # Set seed for reproducibility (Gradio numbers may arrive as floats).
    if seed is not None and seed > 0:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
    else:
        generator = None

    lora_path = f"./LoRAs/{lora_file}"
    prompt = f"Turn this image into the {style_name.replace('_', ' ')} style."

    # The pipeline is a shared singleton: loading another adapter under the
    # same name "style_lora" would fail, and a leftover LoRA would leak into
    # the next request. Always unload it afterwards, even on error.
    try:
        pipe.load_lora_weights(lora_path, adapter_name="style_lora")
        pipe.set_adapters(["style_lora"], adapter_weights=[1.0])
        result = pipe(
            image=image,
            prompt=prompt,
            height=1024,
            width=1024,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            generator=generator,
        )
    finally:
        pipe.unload_lora_weights()
        # Clear GPU memory
        torch.cuda.empty_cache()
    return result.images[0]
# Default "Advanced Settings" values. The sliders and the cached examples both
# read these, so the examples always match the UI defaults.
DEFAULT_STEPS = 24
DEFAULT_GUIDANCE = 3.5
DEFAULT_SEED = 0  # style_transfer treats a non-positive seed as "random"
EXAMPLE_IMAGE_URL = "https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg"


def _run_example(img, style):
    """Run style_transfer on an example input using the default settings."""
    return style_transfer(img, style, DEFAULT_STEPS, DEFAULT_GUIDANCE, DEFAULT_SEED)


# Create Gradio interface
with gr.Blocks(title="Flux Kontext Style Transfer") as demo:
    gr.Markdown("""
# 🎨 Flux Kontext Style Transfer
Transform your images into various artistic styles using FLUX.1-Kontext and style-specific LoRAs.
Upload an image and select a style to apply the transformation!
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400,
            )
            style_dropdown = gr.Dropdown(
                choices=list(style_type_lora_dict.keys()),
                value="3D_Chibi",
                label="Select Style",
                info="Choose the artistic style to apply",
            )
            with gr.Accordion("Advanced Settings", open=False):
                num_steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=DEFAULT_STEPS,
                    step=1,
                    label="Number of Inference Steps",
                    info="More steps = better quality but slower",
                )
                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=DEFAULT_GUIDANCE,
                    step=0.5,
                    label="Guidance Scale",
                    info="Higher values = stronger style adherence",
                )
                seed = gr.Number(
                    label="Seed",
                    value=DEFAULT_SEED,
                    precision=0,
                    info="Set to 0 for random, or use specific seed for reproducibility",
                )
            generate_btn = gr.Button("🎨 Apply Style Transfer", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                label="Styled Output",
                type="pil",
                height=400,
            )

    # Examples: outputs are precomputed (cache_examples=True) using the
    # default advanced settings.
    gr.Examples(
        examples=[
            [EXAMPLE_IMAGE_URL, "3D_Chibi"],
            [EXAMPLE_IMAGE_URL, "Ghibli"],
            [EXAMPLE_IMAGE_URL, "Van_Gogh"],
        ],
        inputs=[input_image, style_dropdown],
        outputs=output_image,
        fn=_run_example,
        cache_examples=True,
    )

    # Connect the generate button
    generate_btn.click(
        fn=style_transfer,
        inputs=[input_image, style_dropdown, num_steps, guidance, seed],
        outputs=output_image,
    )

    gr.Markdown("""
## Notes:
- Processing takes about 30-60 seconds depending on the number of steps
- All images are resized to 1024x1024 for optimal results
- Different styles work better with different types of images
- Try adjusting the advanced settings for better results
## 🎨 Available Styles:
3D Chibi, American Cartoon, Chinese Ink, Clay Toy, Fabric, Ghibli, Irasutoya,
Jojo, Oil Painting, Pixel, Snoopy, Poly, LEGO, Origami, Pop Art, Van Gogh,
Paper Cutting, Line, Vector, Picasso, Macaron, Rick & Morty
""")
# Launch the Gradio app when run as a script (the trailing "|" scrape
# artifact that made this line a syntax error has been removed).
if __name__ == "__main__":
    demo.launch()