ginipick committed on
Commit
24c51c9
·
verified ·
1 Parent(s): 4db91b0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +357 -0
app.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from diffusers import FluxKontextPipeline
5
+ from diffusers.utils import load_image
6
+ from PIL import Image
7
+ import os
8
+
9
# Single source of truth for the supported styles: one row per style holding
# (style key, short description, thumbnail file name).  The three lookup
# tables below are derived from it, so they share the same keys in the same
# order.
_STYLE_TABLE = (
    ("3D_Chibi", "Cute, miniature 3D character style with big heads", "3D_Chibi.webp"),
    ("American_Cartoon", "Classic American animation style", "american_cartoon.webp"),
    ("Chinese_Ink", "Traditional Chinese ink painting aesthetic", "chinese_ink.webp"),
    ("Clay_Toy", "Playful clay/plasticine toy appearance", "clay_toy.webp"),
    ("Fabric", "Soft, textile-like rendering", "fabric.webp"),
    ("Ghibli", "Studio Ghibli's distinctive anime style", "ghibli.webp"),
    ("Irasutoya", "Simple, flat Japanese illustration style", "Irasutoya.webp"),
    ("Jojo", "JoJo's Bizarre Adventure manga style", "jojo.webp"),
    ("Oil_Painting", "Classic oil painting texture and strokes", "oil_painting.webp"),
    ("Pixel", "Retro pixel art style", "pixel.webp"),
    ("Snoopy", "Peanuts comic strip style", "snoopy.webp"),
    ("Poly", "Low-poly 3D geometric style", "poly.webp"),
    ("LEGO", "LEGO brick construction style", "LEGO.webp"),
    ("Origami", "Paper folding art style", "origami.webp"),
    ("Pop_Art", "Bold, colorful pop art style", "pop-art.webp"),
    ("Van_Gogh", "Van Gogh's expressive brushstroke style", "van_gogh.webp"),
    ("Paper_Cutting", "Paper cut-out art style", "Paper_Cutting.webp"),
    ("Line", "Clean line art/sketch style", "line.webp"),
    ("Vector", "Clean vector graphics style", "vector.webp"),
    ("Picasso", "Cubist art style inspired by Picasso", "picasso.webp"),
    ("Macaron", "Soft, pastel macaron-like style", "Macaron.webp"),
    ("Rick_Morty", "Rick and Morty cartoon style", "Rick_Morty.webp"),
)

# Style key -> LoRA weight file name ("<key>_lora_weights.safetensors").
style_type_lora_dict = {
    key: f"{key}_lora_weights.safetensors" for key, _, _ in _STYLE_TABLE
}

# Style key -> short description of the style.
style_descriptions = {key: blurb for key, blurb, _ in _STYLE_TABLE}

# Style key -> thumbnail image file name.
thumbnail_mapping = {key: thumb for key, _, thumb in _STYLE_TABLE}
86
+
87
# Module-level cache for the diffusion pipeline.  It is created lazily by
# load_pipeline() on first use and reused by every later call.
pipeline = None
pipeline_loaded = False


def load_pipeline():
    """Return the shared FLUX.1-Kontext-dev pipeline, creating it on first call.

    The pipeline is stored in the module-level ``pipeline`` variable, so the
    model is only loaded once per process; ``pipeline_loaded`` is set to
    ``True`` after that first load.

    Returns:
        The cached ``FluxKontextPipeline`` instance.
    """
    global pipeline, pipeline_loaded
    if pipeline is None:
        print("Loading FLUX.1-Kontext-dev model...")
        # Auto-detect HF_TOKEN: use the environment variable when it holds a
        # value; otherwise (unset or empty) pass True so the hub client falls
        # back to the locally stored login token.
        token = os.getenv("HF_TOKEN") or True

        pipeline = FluxKontextPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Kontext-dev",
            torch_dtype=torch.bfloat16,
            # `token` is the current keyword; `use_auth_token` is deprecated.
            token=token,
        )
        pipeline_loaded = True
    return pipeline
105
+
106
@spaces.GPU(duration=120)
def style_transfer(
    input_image,
    style_name,
    prompt_suffix,
    num_inference_steps=24,
    guidance_scale=2.5,
    seed=42,
):
    """Apply the selected style LoRA to the input image.

    Args:
        input_image: A PIL image, or a string that is passed to
            ``load_image``.  ``None`` produces a UI warning and no output.
        style_name: Key of ``style_type_lora_dict`` selecting the LoRA file.
        prompt_suffix: Optional extra instructions appended to the prompt.
        num_inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: A value greater than 0 seeds a CUDA generator; ``0``, negative
            values and ``None`` leave the generator unset.

    Returns:
        The first generated image, or ``None`` when no input image was given.

    Raises:
        gr.Error: If the style is unknown or anything fails during loading or
            generation, so the message is shown to the user.
    """
    if input_image is None:
        gr.Warning("Please upload an image first!")
        return None

    if style_name not in style_type_lora_dict:
        raise gr.Error(f"Unknown style: {style_name!r}")

    try:
        # Load pipeline and move to GPU
        pipe = load_pipeline()
        pipe = pipe.to('cuda')

        # Enable memory efficient settings
        # NOTE(review): enable_model_cpu_offload() manages device placement
        # itself, so the explicit .to('cuda') above may be redundant —
        # confirm against the diffusers documentation before removing it.
        pipe.enable_model_cpu_offload()

        # Set seed for reproducibility
        generator = None
        if seed is not None and seed > 0:
            generator = torch.Generator(device="cuda").manual_seed(int(seed))

        # Process input image
        if isinstance(input_image, str):
            image = load_image(input_image)
        else:
            image = input_image

        # Ensure RGB and resize to 1024x1024
        image = image.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)

        # Load the selected LoRA
        lora_filename = style_type_lora_dict[style_name]

        # Clear any previously loaded LoRA.  This stays best-effort, but the
        # reason for a failure is logged instead of being silently dropped.
        try:
            pipe.unload_lora_weights()
        except Exception as unload_error:
            print(f"Could not unload previous LoRA weights: {unload_error}")

        # Load LoRA weights
        pipe.load_lora_weights(
            "Owen777/Kontext-Style-Loras",
            weight_name=lora_filename,
            adapter_name="style"
        )
        pipe.set_adapters(["style"], adapter_weights=[1.0])

        # Create prompt for style transformation
        style_name_readable = style_name.replace('_', ' ')
        prompt = f"Turn this image into the {style_name_readable} style."
        if prompt_suffix and prompt_suffix.strip():
            prompt += f" {prompt_suffix.strip()}"

        print(f"Generating with prompt: {prompt}")

        # Generate the styled image
        result = pipe(
            image=image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            height=1024,
            width=1024
        )

        return result.images[0]

    except Exception as e:
        print(f"Error: {str(e)}")
        # gr.Error is an exception class: it only reaches the UI when raised.
        raise gr.Error(f"Error during style transfer: {str(e)}") from e
    finally:
        # Clear GPU memory on both the success and the failure path.
        torch.cuda.empty_cache()
183
+
184
def create_thumbnail_html():
    """Build the HTML markup for the clickable style thumbnail grid.

    One cell is emitted per key of ``style_type_lora_dict`` (at most 24, for a
    6x4 grid), using the file name from ``thumbnail_mapping`` as image source
    and the key with underscores replaced by spaces as label.  Unused cells
    are padded with empty ``<div>`` elements.

    Returns:
        The complete grid markup as a single string.
    """
    # NOTE(review): each cell's onclick assigns `.value` on the element with
    # id 'style_dropdown' and dispatches a 'change' event; whether that
    # actually updates the Gradio dropdown depends on how Gradio renders
    # elem_id — verify in the browser.
    max_cells = 24  # 6 columns x 4 rows
    style_keys = list(style_type_lora_dict.keys())

    pieces = ['<div style="display: grid; grid-template-columns: repeat(6, 1fr); gap: 10px; max-width: 800px; margin: 0 auto;">']

    for key in style_keys[:max_cells]:
        thumb = thumbnail_mapping.get(key, "")
        label = key.replace('_', ' ')
        pieces.append(f'''
        <div style="text-align: center; cursor: pointer;" onclick="document.getElementById('style_dropdown').value='{key}';
             var event = new Event('change', {{bubbles: true}});
             document.getElementById('style_dropdown').dispatchEvent(event);">
            <img src="file/{thumb}" alt="{label}"
                 style="width: 100%; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
                        transition: transform 0.2s, box-shadow 0.2s;"
                 onmouseover="this.style.transform='scale(1.05)'; this.style.boxShadow='0 4px 8px rgba(0,0,0,0.2)';"
                 onmouseout="this.style.transform='scale(1)'; this.style.boxShadow='0 2px 4px rgba(0,0,0,0.1)';">
            <p style="margin: 5px 0; font-size: 12px; font-weight: 500;">{label}</p>
        </div>
        ''')

    # Pad the grid with empty cells; range() of a non-positive count is empty,
    # so nothing is added once there are 24 or more styles.
    pieces.extend('<div></div>' for _ in range(max_cells - len(style_keys)))

    pieces.append('</div>')
    return ''.join(pieces)
218
+
219
# Create Gradio interface
#
# Initial values of the advanced settings.  They are shared by the controls
# below and by the example runner, so example rows (which only carry image,
# style and suffix) run with the same settings the UI starts with.
_DEFAULT_STEPS = 24
_DEFAULT_GUIDANCE = 2.5
_DEFAULT_SEED = 42

with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 FLUX.1 Kontext Style Transfer

    Transform your images into various artistic styles using FLUX.1-Kontext-dev and high-quality style LoRAs.

    This demo uses the official Owen777/Kontext-Style-Loras collection with 22 different artistic styles!
    """)

    # Thumbnail Grid Section
    gr.Markdown("### 🖼️ Click a style thumbnail to select it:")
    with gr.Row():
        gr.HTML(create_thumbnail_html())

    gr.Markdown("---")

    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                height=400
            )

            # elem_id is referenced by the thumbnail grid's onclick handlers.
            style_dropdown = gr.Dropdown(
                choices=list(style_type_lora_dict.keys()),
                value="Ghibli",
                label="Selected Style",
                info="Choose from 22 different artistic styles or click a thumbnail above",
                elem_id="style_dropdown"
            )

            style_info = gr.Textbox(
                label="Style Description",
                value=style_descriptions["Ghibli"],
                interactive=False,
                lines=2
            )

            prompt_suffix = gr.Textbox(
                label="Additional Instructions (Optional)",
                placeholder="Add extra details like 'make it more colorful' or 'add dramatic lighting'...",
                lines=2
            )

            with gr.Accordion("Advanced Settings", open=False):
                num_steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=_DEFAULT_STEPS,
                    step=1,
                    label="Inference Steps",
                    info="More steps = better quality but slower"
                )

                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    value=_DEFAULT_GUIDANCE,
                    step=0.1,
                    label="Guidance Scale",
                    info="How closely to follow the prompt (2.5 recommended)"
                )

                seed = gr.Number(
                    label="Seed",
                    value=_DEFAULT_SEED,
                    precision=0,
                    info="Set to 0 for random results"
                )

            generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg")

        # Right column: result and usage tips.
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Styled Result",
                type="pil",
                height=400
            )

            gr.Markdown("""
            ### 💡 Tips:
            - Click any thumbnail above to quickly select a style
            - All images are resized to 1024x1024
            - First run downloads the model (~12GB)
            - Each style transformation takes ~30-60 seconds
            - Try different styles to find the best match!
            - Use additional instructions for fine control
            """)

    # Update style description when style changes
    def update_description(style):
        """Return the description for ``style`` ("" when it is unknown)."""
        return style_descriptions.get(style, "")

    style_dropdown.change(
        fn=update_description,
        inputs=[style_dropdown],
        outputs=[style_info]
    )

    def run_example(image, style, suffix):
        """Run ``style_transfer`` for a three-column example row.

        Example rows do not include the advanced settings, so the initial
        slider/seed values are supplied here.  Passing ``style_transfer``
        directly would call it with too few arguments.
        """
        return style_transfer(
            image, style, suffix, _DEFAULT_STEPS, _DEFAULT_GUIDANCE, _DEFAULT_SEED
        )

    # Examples
    gr.Examples(
        examples=[
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Ghibli", ""],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "3D_Chibi", "make it extra cute"],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Van_Gogh", "with swirling sky"],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "Pixel", "8-bit retro game style"],
        ],
        inputs=[input_image, style_dropdown, prompt_suffix],
        outputs=output_image,
        fn=run_example,
        cache_examples=False
    )

    # Connect the generate button
    generate_btn.click(
        fn=style_transfer,
        inputs=[input_image, style_dropdown, prompt_suffix, num_steps, guidance, seed],
        outputs=output_image
    )

    gr.Markdown("""
    ---
    ### 📚 Available Styles:

    **Anime/Cartoon**: Ghibli, American Cartoon, Jojo, Snoopy, Rick & Morty, Irasutoya
    **3D/Geometric**: 3D Chibi, Poly, LEGO, Clay Toy
    **Traditional Art**: Chinese Ink, Oil Painting, Van Gogh, Picasso, Pop Art
    **Craft/Material**: Fabric, Origami, Paper Cutting, Macaron
    **Digital/Modern**: Pixel, Line, Vector

    ---

    Created with ❤️ using [Owen777/Kontext-Style-Loras](https://huggingface.co/Owen777/Kontext-Style-Loras)
    """)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()