Spaces:
Running
on
Zero
Running
on
Zero
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import spaces
|
3 |
+
import torch
|
4 |
+
from huggingface_hub import hf_hub_download
|
5 |
+
from diffusers import FluxKontextPipeline
|
6 |
+
from diffusers.utils import load_image
|
7 |
+
from PIL import Image
|
8 |
+
import os
|
9 |
+
|
10 |
+
# Style dictionary
|
11 |
+
style_type_lora_dict = {
|
12 |
+
"3D_Chibi": "3D_Chibi_lora_weights.safetensors",
|
13 |
+
"American_Cartoon": "American_Cartoon_lora_weights.safetensors",
|
14 |
+
"Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
|
15 |
+
"Clay_Toy": "Clay_Toy_lora_weights.safetensors",
|
16 |
+
"Fabric": "Fabric_lora_weights.safetensors",
|
17 |
+
"Ghibli": "Ghibli_lora_weights.safetensors",
|
18 |
+
"Irasutoya": "Irasutoya_lora_weights.safetensors",
|
19 |
+
"Jojo": "Jojo_lora_weights.safetensors",
|
20 |
+
"Oil_Painting": "Oil_Painting_lora_weights.safetensors",
|
21 |
+
"Pixel": "Pixel_lora_weights.safetensors",
|
22 |
+
"Snoopy": "Snoopy_lora_weights.safetensors",
|
23 |
+
"Poly": "Poly_lora_weights.safetensors",
|
24 |
+
"LEGO": "LEGO_lora_weights.safetensors",
|
25 |
+
"Origami": "Origami_lora_weights.safetensors",
|
26 |
+
"Pop_Art": "Pop_Art_lora_weights.safetensors",
|
27 |
+
"Van_Gogh": "Van_Gogh_lora_weights.safetensors",
|
28 |
+
"Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
|
29 |
+
"Line": "Line_lora_weights.safetensors",
|
30 |
+
"Vector": "Vector_lora_weights.safetensors",
|
31 |
+
"Picasso": "Picasso_lora_weights.safetensors",
|
32 |
+
"Macaron": "Macaron_lora_weights.safetensors",
|
33 |
+
"Rick_Morty": "Rick_Morty_lora_weights.safetensors"
|
34 |
+
}
|
35 |
+
|
36 |
+
# Create LoRAs directory if it doesn't exist
|
37 |
+
os.makedirs("./LoRAs", exist_ok=True)
|
38 |
+
|
39 |
+
# Download all LoRA weights at startup
|
40 |
+
print("Downloading LoRA weights...")
|
41 |
+
for style_name, lora_file in style_type_lora_dict.items():
|
42 |
+
if not os.path.exists(f"./LoRAs/{lora_file}"):
|
43 |
+
hf_hub_download(
|
44 |
+
repo_id="Owen777/Kontext-Style-Loras",
|
45 |
+
filename=lora_file,
|
46 |
+
local_dir="./LoRAs"
|
47 |
+
)
|
48 |
+
print("All LoRA weights downloaded!")
|
49 |
+
|
50 |
+
# Initialize pipeline globally (will be loaded to GPU when needed)
|
51 |
+
pipeline = None
|
52 |
+
|
53 |
+
def load_pipeline():
|
54 |
+
global pipeline
|
55 |
+
if pipeline is None:
|
56 |
+
pipeline = FluxKontextPipeline.from_pretrained(
|
57 |
+
"black-forest-labs/FLUX.1-Kontext-dev",
|
58 |
+
torch_dtype=torch.bfloat16
|
59 |
+
)
|
60 |
+
return pipeline
|
61 |
+
|
62 |
+
@spaces.GPU(duration=120) # Request GPU for 120 seconds
|
63 |
+
def style_transfer(input_image, style_name, num_inference_steps, guidance_scale, seed):
|
64 |
+
"""
|
65 |
+
Apply style transfer to the input image using selected style
|
66 |
+
"""
|
67 |
+
# Load pipeline and move to GPU
|
68 |
+
pipe = load_pipeline()
|
69 |
+
pipe = pipe.to('cuda')
|
70 |
+
|
71 |
+
# Set seed for reproducibility
|
72 |
+
if seed is not None and seed > 0:
|
73 |
+
generator = torch.Generator(device="cuda").manual_seed(seed)
|
74 |
+
else:
|
75 |
+
generator = None
|
76 |
+
|
77 |
+
# Resize input image to 1024x1024
|
78 |
+
if isinstance(input_image, str):
|
79 |
+
image = load_image(input_image)
|
80 |
+
else:
|
81 |
+
image = input_image
|
82 |
+
|
83 |
+
image = image.resize((1024, 1024), Image.Resampling.LANCZOS)
|
84 |
+
|
85 |
+
# Load the selected LoRA
|
86 |
+
lora_path = f"./LoRAs/{style_type_lora_dict[style_name]}"
|
87 |
+
pipe.load_lora_weights(lora_path, adapter_name="style_lora")
|
88 |
+
pipe.set_adapters(["style_lora"], adapter_weights=[1.0])
|
89 |
+
|
90 |
+
# Generate the styled image
|
91 |
+
prompt = f"Turn this image into the {style_name.replace('_', ' ')} style."
|
92 |
+
|
93 |
+
result = pipe(
|
94 |
+
image=image,
|
95 |
+
prompt=prompt,
|
96 |
+
height=1024,
|
97 |
+
width=1024,
|
98 |
+
num_inference_steps=num_inference_steps,
|
99 |
+
guidance_scale=guidance_scale,
|
100 |
+
generator=generator
|
101 |
+
)
|
102 |
+
|
103 |
+
# Clear GPU memory
|
104 |
+
torch.cuda.empty_cache()
|
105 |
+
|
106 |
+
return result.images[0]
|
107 |
+
|
108 |
+
# Create Gradio interface
|
109 |
+
with gr.Blocks(title="Flux Kontext Style Transfer") as demo:
|
110 |
+
gr.Markdown("""
|
111 |
+
# π¨ Flux Kontext Style Transfer
|
112 |
+
|
113 |
+
Transform your images into various artistic styles using FLUX.1-Kontext and style-specific LoRAs.
|
114 |
+
|
115 |
+
Upload an image and select a style to apply the transformation!
|
116 |
+
""")
|
117 |
+
|
118 |
+
with gr.Row():
|
119 |
+
with gr.Column():
|
120 |
+
input_image = gr.Image(
|
121 |
+
label="Input Image",
|
122 |
+
type="pil",
|
123 |
+
height=400
|
124 |
+
)
|
125 |
+
|
126 |
+
style_dropdown = gr.Dropdown(
|
127 |
+
choices=list(style_type_lora_dict.keys()),
|
128 |
+
value="3D_Chibi",
|
129 |
+
label="Select Style",
|
130 |
+
info="Choose the artistic style to apply"
|
131 |
+
)
|
132 |
+
|
133 |
+
with gr.Accordion("Advanced Settings", open=False):
|
134 |
+
num_steps = gr.Slider(
|
135 |
+
minimum=10,
|
136 |
+
maximum=50,
|
137 |
+
value=24,
|
138 |
+
step=1,
|
139 |
+
label="Number of Inference Steps",
|
140 |
+
info="More steps = better quality but slower"
|
141 |
+
)
|
142 |
+
|
143 |
+
guidance = gr.Slider(
|
144 |
+
minimum=1.0,
|
145 |
+
maximum=10.0,
|
146 |
+
value=3.5,
|
147 |
+
step=0.5,
|
148 |
+
label="Guidance Scale",
|
149 |
+
info="Higher values = stronger style adherence"
|
150 |
+
)
|
151 |
+
|
152 |
+
seed = gr.Number(
|
153 |
+
label="Seed",
|
154 |
+
value=0,
|
155 |
+
precision=0,
|
156 |
+
info="Set to 0 for random, or use specific seed for reproducibility"
|
157 |
+
)
|
158 |
+
|
159 |
+
generate_btn = gr.Button("π¨ Apply Style Transfer", variant="primary")
|
160 |
+
|
161 |
+
with gr.Column():
|
162 |
+
output_image = gr.Image(
|
163 |
+
label="Styled Output",
|
164 |
+
type="pil",
|
165 |
+
height=400
|
166 |
+
)
|
167 |
+
|
168 |
+
# Examples
|
169 |
+
gr.Examples(
|
170 |
+
examples=[
|
171 |
+
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi"],
|
172 |
+
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli"],
|
173 |
+
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh"],
|
174 |
+
],
|
175 |
+
inputs=[input_image, style_dropdown],
|
176 |
+
outputs=output_image,
|
177 |
+
fn=lambda img, style: style_transfer(img, style, 24, 3.5, 0),
|
178 |
+
cache_examples=True
|
179 |
+
)
|
180 |
+
|
181 |
+
# Connect the generate button
|
182 |
+
generate_btn.click(
|
183 |
+
fn=style_transfer,
|
184 |
+
inputs=[input_image, style_dropdown, num_steps, guidance, seed],
|
185 |
+
outputs=output_image
|
186 |
+
)
|
187 |
+
|
188 |
+
gr.Markdown("""
|
189 |
+
## π Notes:
|
190 |
+
- Processing takes about 30-60 seconds depending on the number of steps
|
191 |
+
- All images are resized to 1024x1024 for optimal results
|
192 |
+
- Different styles work better with different types of images
|
193 |
+
- Try adjusting the advanced settings for better results
|
194 |
+
|
195 |
+
## π¨ Available Styles:
|
196 |
+
3D Chibi, American Cartoon, Chinese Ink, Clay Toy, Fabric, Ghibli, Irasutoya,
|
197 |
+
Jojo, Oil Painting, Pixel, Snoopy, Poly, LEGO, Origami, Pop Art, Van Gogh,
|
198 |
+
Paper Cutting, Line, Vector, Picasso, Macaron, Rick & Morty
|
199 |
+
""")
|
200 |
+
|
201 |
+
if __name__ == "__main__":
|
202 |
+
demo.launch()
|