ginipick committed (verified)
Commit 0c2c127 · Parent(s): dd4f8c0

Create app.py

Files changed (1): app.py +533 -0
app.py ADDED
@@ -0,0 +1,533 @@
+ import gradio as gr
+ import numpy as np
+ import spaces
+ import torch
+ import random
+ import json
+ import os
+ from PIL import Image
+ from diffusers import FluxKontextPipeline
+ from diffusers.utils import load_image
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
+ from safetensors.torch import load_file
+ import requests
+ import re
+
+ # Load Kontext model
+ MAX_SEED = np.iinfo(np.int32).max
+
+ pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
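+ # NOTE: .to("cuda") assumes a CUDA device with enough memory for the full
+ # Kontext pipeline in bfloat16; on a Hugging Face ZeroGPU Space the
+ # @spaces.GPU decorator further below handles the actual GPU allocation.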
+
+ # Load LoRA data (you'll need to create this JSON file or modify to load your LoRAs)
+ with open("flux_loras.json", "r") as file:
+     data = json.load(file)
+     flux_loras_raw = [
+         {
+             "image": item["image"],
+             "title": item["title"],
+             "repo": item["repo"],
+             "trigger_word": item.get("trigger_word", ""),
+             "trigger_position": item.get("trigger_position", "prepend"),
+             "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
+             "likes": item.get("likes", 0),  # carried through so classify_gallery() can actually sort by it
+         }
+         for item in data
+     ]
+ print(f"Loaded {len(flux_loras_raw)} LoRAs from JSON")
+
+ # Global variables for LoRA management
+ current_lora = None
+ lora_cache = {}
+
+ def load_lora_weights(repo_id, weights_filename):
+     """Load LoRA weights from HuggingFace"""
+     try:
+         if repo_id not in lora_cache:
+             lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
+             lora_cache[repo_id] = lora_path
+         return lora_cache[repo_id]
+     except Exception as e:
+         print(f"Error loading LoRA from {repo_id}: {e}")
+         return None
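+ # hf_hub_download() already caches downloads in the local HF cache directory,
+ # so lora_cache mainly avoids repeated hub lookups within a single session.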
+
+ def update_selection(selected_state: gr.SelectData, flux_loras):
+     """Update UI when a LoRA is selected"""
+     if selected_state.index >= len(flux_loras):
+         return "### No LoRA selected", gr.update(), None
+
+     lora_repo = flux_loras[selected_state.index]["repo"]
+     trigger_word = flux_loras[selected_state.index]["trigger_word"]
+
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
+     new_placeholder = "optional description, e.g. 'a man with glasses and a beard'"
+
+     return updated_text, gr.update(placeholder=new_placeholder), selected_state.index
+
+ def get_huggingface_lora(link):
+     """Download LoRA from HuggingFace link"""
+     split_link = link.split("/")
+     if len(split_link) == 2:
+         try:
+             model_card = ModelCard.load(link)
+             trigger_word = model_card.data.get("instance_prompt", "")
+
+             fs = HfFileSystem()
+             list_of_files = fs.ls(link, detail=False)
+             safetensors_file = None
+
+             for file in list_of_files:
+                 if file.endswith(".safetensors") and "lora" in file.lower():
+                     safetensors_file = file.split("/")[-1]
+                     break
+
+             if not safetensors_file:
+                 safetensors_file = "pytorch_lora_weights.safetensors"
+
+             return split_link[1], safetensors_file, trigger_word
+         except Exception as e:
+             raise Exception(f"Error loading LoRA: {e}")
+     else:
+         raise Exception("Invalid HuggingFace repository format")
+
+ def load_custom_lora(link):
+     """Load custom LoRA from user input"""
+     if not link:
+         return gr.update(visible=False), gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on a LoRA in the gallery to select it", None
+
+     try:
+         repo_name, weights_file, trigger_word = get_huggingface_lora(link)
+
+         card = f'''
+         <div style="border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 10px 0;">
+             <span><strong>Loaded custom LoRA:</strong></span>
+             <div style="margin-top: 8px;">
+                 <h4>{repo_name}</h4>
+                 <small>{"Using: <code><b>"+trigger_word+"</b></code> as trigger word" if trigger_word else "No trigger word found"}</small>
+             </div>
+         </div>
+         '''
+
+         custom_lora_data = {
+             "repo": link,
+             "weights": weights_file,
+             "trigger_word": trigger_word
+         }
+
+         # Return the card markup and its visibility in a single update so the
+         # event handler below lists each output component only once.
+         return gr.update(value=card, visible=True), gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"Custom: {repo_name}", None
+
+     except Exception as e:
+         return gr.update(value=f"Error: {str(e)}", visible=True), gr.update(visible=False), None, gr.update(), "### Click on a LoRA in the gallery to select it", None
+
+ def remove_custom_lora():
+     """Remove custom LoRA"""
+     return "", gr.update(visible=False), gr.update(visible=False), None, None
+
+ def classify_gallery(flux_loras):
+     """Sort gallery by likes"""
+     sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
+     return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
+
+ def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
+     """Wrapper function to handle state serialization"""
+     return infer_with_lora(input_image, prompt, selected_index, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, flux_loras, progress)
+
+ @spaces.GPU
+ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
+     """Generate image with selected LoRA"""
+     global current_lora, pipe
+
+     if input_image is None:
+         raise gr.Error("Please upload an image first.")
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     # Determine which LoRA to use
+     lora_to_use = None
+     if custom_lora:
+         lora_to_use = custom_lora
+     elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
+         lora_to_use = flux_loras[selected_index]
+         print(f"Using gallery LoRA: {lora_to_use['repo']}")
+
+     # Load LoRA if needed
+     if lora_to_use and lora_to_use != current_lora:
+         try:
+             # Unload current LoRA
+             if current_lora:
+                 pipe.unload_lora_weights()
+
+             # Load new LoRA
+             lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
+             if lora_path:
+                 pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
+                 pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
+                 print(f"loaded: {lora_path} with scale {lora_scale}")
+             current_lora = lora_to_use
+
+         except Exception as e:
+             print(f"Error loading LoRA: {e}")
+             # Continue without LoRA
+     elif lora_to_use:
+         print(f"using already loaded lora: {lora_to_use}")
+
+     input_image = input_image.convert("RGB")
+
+     # Add the LoRA's trigger word to the prompt; with no LoRA selected the prompt is used as-is
+     trigger_word = lora_to_use["trigger_word"] if lora_to_use else ""
+     if trigger_word == ", How2Draw":
+         prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
+     elif trigger_word == ", video game screenshot in the style of THSMS":
+         prompt = f"create a video game screenshot in the style of THSMS with the person from the photo, {prompt}. maintain the facial identity of the person and general features"
+     elif trigger_word:
+         prompt = f"convert the style of this portrait photo to {trigger_word} while maintaining the identity of the person. {prompt}. Make sure to maintain the person's facial identity and features, while still changing the overall style to {trigger_word}."
+
+     try:
+         image = pipe(
+             image=input_image,
+             prompt=prompt,
+             guidance_scale=guidance_scale,
+             generator=torch.Generator().manual_seed(seed),  # CPU generator; reproducible for a given seed
+         ).images[0]
+
+         return image, seed, gr.update(visible=True)
+
+     except Exception as e:
+         print(f"Error during inference: {e}")
+         return None, seed, gr.update(visible=False)
+
+ # CSS styling with beautiful gradient pastel design
+ css = """
+ /* Global background and container styling */
+ .gradio-container {
+     background: linear-gradient(135deg, #ffeef8 0%, #e6f3ff 25%, #fff4e6 50%, #f0e6ff 75%, #e6fff9 100%);
+     font-family: 'Inter', sans-serif;
+ }
+
+ /* Main app container */
+ #main_app {
+     display: flex;
+     gap: 24px;
+     padding: 20px;
+     background: rgba(255, 255, 255, 0.85);
+     backdrop-filter: blur(20px);
+     border-radius: 24px;
+     box-shadow: 0 10px 40px rgba(0, 0, 0, 0.08);
+ }
+
+ /* Box column styling */
+ #box_column {
+     min-width: 400px;
+ }
+
+ /* Gallery box with glassmorphism */
+ #gallery_box {
+     background: linear-gradient(135deg, rgba(255, 255, 255, 0.9) 0%, rgba(240, 248, 255, 0.9) 100%);
+     border-radius: 20px;
+     padding: 20px;
+     box-shadow: 0 8px 32px rgba(135, 206, 250, 0.2);
+     border: 1px solid rgba(255, 255, 255, 0.8);
+ }
+
+ /* Input image styling */
+ .image-container {
+     border-radius: 16px;
+     overflow: hidden;
+     box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
+ }
+
+ /* Gallery styling */
+ #gallery {
+     overflow-y: scroll !important;
+     max-height: 400px;
+     padding: 12px;
+     background: rgba(255, 255, 255, 0.5);
+     border-radius: 16px;
+     scrollbar-width: thin;
+     scrollbar-color: #ddd6fe #f5f3ff;
+ }
+
+ #gallery::-webkit-scrollbar {
+     width: 8px;
+ }
+
+ #gallery::-webkit-scrollbar-track {
+     background: #f5f3ff;
+     border-radius: 10px;
+ }
+
+ #gallery::-webkit-scrollbar-thumb {
+     background: linear-gradient(180deg, #c7d2fe 0%, #ddd6fe 100%);
+     border-radius: 10px;
+ }
+
+ /* Selected LoRA text */
+ #selected_lora {
+     background: linear-gradient(135deg, #818cf8 0%, #a78bfa 100%);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     background-clip: text;
+     font-weight: 700;
+     font-size: 18px;
+     text-align: center;
+     padding: 12px;
+     margin-bottom: 16px;
+ }
+
+ /* Prompt input field */
+ #prompt {
+     flex-grow: 1;
+     border: 2px solid transparent;
+     background: linear-gradient(white, white) padding-box,
+                 linear-gradient(135deg, #a5b4fc 0%, #e9d5ff 100%) border-box;
+     border-radius: 12px;
+     padding: 12px 16px;
+     font-size: 16px;
+     transition: all 0.3s ease;
+ }
+
+ #prompt:focus {
+     box-shadow: 0 0 0 4px rgba(165, 180, 252, 0.25);
+ }
+
+ /* Run button with animated gradient */
+ #run_button {
+     background: linear-gradient(135deg, #a78bfa 0%, #818cf8 25%, #60a5fa 50%, #34d399 75%, #fbbf24 100%);
+     background-size: 200% 200%;
+     animation: gradient-shift 3s ease infinite;
+     color: white;
+     border: none;
+     padding: 12px 32px;
+     border-radius: 12px;
+     font-weight: 600;
+     font-size: 16px;
+     cursor: pointer;
+     transition: all 0.3s ease;
+     box-shadow: 0 4px 20px rgba(167, 139, 250, 0.4);
+ }
+
+ #run_button:hover {
+     transform: translateY(-2px);
+     box-shadow: 0 6px 30px rgba(167, 139, 250, 0.6);
+ }
+
+ @keyframes gradient-shift {
+     0% { background-position: 0% 50%; }
+     50% { background-position: 100% 50%; }
+     100% { background-position: 0% 50%; }
+ }
+
+ /* Custom LoRA card */
+ .custom_lora_card {
+     background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
+     border: 1px solid #fcd34d;
+     border-radius: 12px;
+     padding: 16px;
+     margin: 12px 0;
+     box-shadow: 0 4px 12px rgba(251, 191, 36, 0.2);
+ }
+
+ /* Result image container */
+ .output-image {
+     border-radius: 16px;
+     overflow: hidden;
+     box-shadow: 0 8px 32px rgba(0, 0, 0, 0.12);
+     margin-top: 20px;
+ }
+
+ /* Accordion styling */
+ .accordion {
+     background: rgba(249, 250, 251, 0.9);
+     border-radius: 12px;
+     border: 1px solid rgba(229, 231, 235, 0.8);
+     margin-top: 16px;
+ }
+
+ /* Slider styling */
+ .slider-container {
+     padding: 8px 0;
+ }
+
+ input[type="range"] {
+     background: linear-gradient(to right, #e0e7ff 0%, #c7d2fe 100%);
+     border-radius: 8px;
+     height: 6px;
+ }
+
+ /* Reuse button */
+ button:not(#run_button) {
+     background: linear-gradient(135deg, #f0abfc 0%, #c084fc 100%);
+     color: white;
+     border: none;
+     padding: 8px 20px;
+     border-radius: 8px;
+     font-weight: 500;
+     cursor: pointer;
+     transition: all 0.3s ease;
+ }
+
+ button:not(#run_button):hover {
+     transform: translateY(-1px);
+     box-shadow: 0 4px 16px rgba(192, 132, 252, 0.4);
+ }
+
+ /* Title styling */
+ h1 {
+     background: linear-gradient(135deg, #6366f1 0%, #a855f7 25%, #ec4899 50%, #f43f5e 75%, #f59e0b 100%);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     background-clip: text;
+     text-align: center;
+     font-size: 3.5rem;
+     font-weight: 800;
+     margin-bottom: 8px;
+     text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);
+ }
+
+ h1 small {
+     display: block;
+     background: linear-gradient(135deg, #94a3b8 0%, #64748b 100%);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     background-clip: text;
+     font-size: 1rem;
+     font-weight: 500;
+     margin-top: 8px;
+ }
+
+ /* Checkbox styling */
+ input[type="checkbox"] {
+     accent-color: #8b5cf6;
+ }
+
+ /* Label styling */
+ label {
+     color: #4b5563;
+     font-weight: 500;
+ }
+
+ /* Group containers */
+ .gr-group {
+     background: rgba(255, 255, 255, 0.7);
+     border-radius: 16px;
+     padding: 20px;
+     border: 1px solid rgba(255, 255, 255, 0.9);
+     box-shadow: 0 4px 16px rgba(0, 0, 0, 0.05);
+ }
+ """
+
+ # Create Gradio interface
+ with gr.Blocks(css=css) as demo:
+     gr_flux_loras = gr.State(value=flux_loras_raw)
+
+     title = gr.HTML(
+         """<h1>✨ Flux-Kontext FaceLORA
+         <small>Transform your portraits with AI-powered style transfer 🎨</small></h1>""",
+     )
+
+     selected_state = gr.State(value=None)
+     custom_loaded_lora = gr.State(value=None)
+
+     with gr.Row(elem_id="main_app"):
+         with gr.Column(scale=4, elem_id="box_column"):
+             with gr.Group(elem_id="gallery_box"):
+                 input_image = gr.Image(label="Upload a picture of yourself", type="pil", height=300)
+
+                 gallery = gr.Gallery(
+                     label="Pick a LoRA",
+                     allow_preview=False,
+                     columns=3,
+                     elem_id="gallery",
+                     show_share_button=False,
+                     height=400
+                 )
+
+                 custom_model = gr.Textbox(
+                     label="Or enter a custom HuggingFace FLUX LoRA",
+                     placeholder="e.g., username/lora-name",
+                     visible=False
+                 )
+                 custom_model_card = gr.HTML(visible=False)
+                 custom_model_button = gr.Button("Remove custom LoRA", visible=False)
+
+         with gr.Column(scale=5):
+             with gr.Row():
+                 prompt = gr.Textbox(
+                     label="Editing Prompt",
+                     show_label=False,
+                     lines=1,
+                     max_lines=1,
+                     placeholder="optional description, e.g. 'a man with glasses and a beard'",
+                     elem_id="prompt"
+                 )
+                 run_button = gr.Button("Generate ✨", elem_id="run_button")
+
+             result = gr.Image(label="Generated Image", interactive=False)
+             reuse_button = gr.Button("🔄 Reuse this image", visible=False)
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 lora_scale = gr.Slider(
+                     label="LoRA Scale",
+                     minimum=0,
+                     maximum=2,
+                     step=0.1,
+                     value=1.5,
+                     info="Controls the strength of the LoRA effect"
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                 guidance_scale = gr.Slider(
+                     label="Guidance Scale",
+                     minimum=1,
+                     maximum=10,
+                     step=0.1,
+                     value=2.5,
+                 )
+
+     prompt_title = gr.Markdown(
+         value="### Click on a LoRA in the gallery to select it",
+         visible=True,
+         elem_id="selected_lora",
+     )
+
+     # Event handlers
+     custom_model.input(
+         fn=load_custom_lora,
+         inputs=[custom_model],
+         outputs=[custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title, selected_state],
+     )
+
+     custom_model_button.click(
+         fn=remove_custom_lora,
+         outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora, selected_state]
+     )
+
+     gallery.select(
+         fn=update_selection,
+         inputs=[gr_flux_loras],
+         outputs=[prompt_title, prompt, selected_state],
+         show_progress=False
+     )
+
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=infer_with_lora_wrapper,
+         inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
+         outputs=[result, seed, reuse_button]
+     )
+
+     reuse_button.click(
+         fn=lambda image: image,
+         inputs=[result],
+         outputs=[input_image]
+     )
+
+     # Initialize gallery
+     demo.load(
+         fn=classify_gallery,
+         inputs=[gr_flux_loras],
+         outputs=[gallery, gr_flux_loras]
+     )
+
+ demo.queue(default_concurrency_limit=None)
+ demo.launch()
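+ # A plausible requirements.txt for this Space (an assumption; the commit does
+ # not include one):
+ #   gradio
+ #   spaces
+ #   torch
+ #   diffusers  # needs a version that ships FluxKontextPipeline
+ #   transformers
+ #   safetensors
+ #   huggingface_hub
+ #   Pillow
+ #   numpy
+ #   requests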