prithivMLmods committed
Commit 0c8e12c · verified · 1 Parent(s): 15dbdac

Update app.py

Files changed (1)
  1. app.py +419 -378
app.py CHANGED
@@ -1,401 +1,442 @@
1
- import os
2
- import random
3
- import uuid
4
- import json
5
- import time
6
- import asyncio
7
- from threading import Thread
8
-
9
  import gradio as gr
 
10
  import spaces
11
  import torch
12
- import numpy as np
 
 
13
  from PIL import Image
14
- import cv2
15
-
16
- from transformers import (
17
- Qwen2_5_VLForConditionalGeneration,
18
- AutoProcessor,
19
- TextIteratorStreamer,
20
- )
21
- from transformers.image_utils import load_image
22
 
23
- # Constants for text generation
24
- MAX_MAX_NEW_TOKENS = 2048
25
- DEFAULT_MAX_NEW_TOKENS = 1024
26
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
27
 
28
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29
 
30
- # Load Vision-Matters-7B
31
- MODEL_ID_M = "Yuting6/Vision-Matters-7B"
32
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
33
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
34
- MODEL_ID_M, trust_remote_code=True,
35
- torch_dtype=torch.float16).to(device).eval()
36
-
37
- # Load ViGaL-7B
38
- MODEL_ID_X = "yunfeixie/ViGaL-7B"
39
- processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
40
- model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
41
- MODEL_ID_X, trust_remote_code=True,
42
- torch_dtype=torch.float16).to(device).eval()
 
43
 
44
- # Load prithivMLmods/WR30a-Deep-7B-0711
45
- MODEL_ID_T = "prithivMLmods/WR30a-Deep-7B-0711"
46
- processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
47
- model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
48
- MODEL_ID_T, trust_remote_code=True,
49
- torch_dtype=torch.float16).to(device).eval()
 
50
 
51
- # Load Visionary-R1
52
- MODEL_ID_O = "maifoundations/Visionary-R1"
53
- processor_o = AutoProcessor.from_pretrained(MODEL_ID_O, trust_remote_code=True)
54
- model_o = Qwen2_5_VLForConditionalGeneration.from_pretrained(
55
- MODEL_ID_O, trust_remote_code=True,
56
- torch_dtype=torch.float16).to(device).eval()
57
 
58
- #-----------------------------subfolder-----------------------------#
59
- # Load MonkeyOCR-pro-1.2B
60
- MODEL_ID_W = "echo840/MonkeyOCR-pro-1.2B"
61
- SUBFOLDER = "Recognition"
62
- processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True, subfolder=SUBFOLDER)
63
- model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
64
- MODEL_ID_W, trust_remote_code=True,
65
- subfolder=SUBFOLDER,
66
- torch_dtype=torch.float16).to(device).eval()
67
- #-----------------------------subfolder-----------------------------#
68
 
69
- # Function to downsample video frames
70
- def downsample_video(video_path):
71
- """
72
- Downsamples the video to evenly spaced frames.
73
- Each frame is returned as a PIL image along with its timestamp.
74
- """
75
- vidcap = cv2.VideoCapture(video_path)
76
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
77
- fps = vidcap.get(cv2.CAP_PROP_FPS)
78
- frames = []
79
- frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
80
- for i in frame_indices:
81
- vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
82
- success, image = vidcap.read()
83
- if success:
84
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
85
- pil_image = Image.fromarray(image)
86
- timestamp = round(i / fps, 2)
87
- frames.append((pil_image, timestamp))
88
- vidcap.release()
89
- return frames
90
 
91
- # Function to generate text responses based on image input
92
- @spaces.GPU
93
- def generate_image(model_name: str,
94
- text: str,
95
- image: Image.Image,
96
- max_new_tokens: int = 1024,
97
- temperature: float = 0.6,
98
- top_p: float = 0.9,
99
- top_k: int = 50,
100
- repetition_penalty: float = 1.2):
101
- """
102
- Generates responses using the selected model for image input.
103
- """
104
- if model_name == "Vision-Matters-7B":
105
- processor = processor_m
106
- model = model_m
107
- elif model_name == "ViGaL-7B":
108
- processor = processor_x
109
- model = model_x
110
- elif model_name == "Visionary-R1-3B":
111
- processor = processor_o
112
- model = model_o
113
- elif model_name == "WR30a-Deep-7B-0711":
114
- processor = processor_t
115
- model = model_t
116
- elif model_name == "MonkeyOCR-pro-1.2B":
117
- processor = processor_w
118
- model = model_w
119
- else:
120
- yield "Invalid model selected.", "Invalid model selected."
121
- return
122
 
123
- if image is None:
124
- yield "Please upload an image.", "Please upload an image."
125
- return
126
 
127
- messages = [{
128
- "role": "user",
129
- "content": [
130
- {"type": "image", "image": image},
131
- {"type": "text", "text": text},
132
- ]
133
- }]
134
- prompt_full = processor.apply_chat_template(messages,
135
- tokenize=False,
136
- add_generation_prompt=True)
137
- inputs = processor(text=[prompt_full],
138
- images=[image],
139
- return_tensors="pt",
140
- padding=True,
141
- truncation=False,
142
- max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
143
- streamer = TextIteratorStreamer(processor,
144
- skip_prompt=True,
145
- skip_special_tokens=True)
146
- generation_kwargs = {
147
- **inputs, "streamer": streamer,
148
- "max_new_tokens": max_new_tokens
149
- }
150
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
151
- thread.start()
152
- buffer = ""
153
- for new_text in streamer:
154
- buffer += new_text
155
- time.sleep(0.01)
156
- yield buffer, buffer
157
 
158
- # Function to generate text responses based on video input
159
  @spaces.GPU
160
- def generate_video(model_name: str,
161
- text: str,
162
- video_path: str,
163
- max_new_tokens: int = 1024,
164
- temperature: float = 0.6,
165
- top_p: float = 0.9,
166
- top_k: int = 50,
167
- repetition_penalty: float = 1.2):
168
- """
169
- Generates responses using the selected model for video input.
170
- """
171
- if model_name == "Vision-Matters-7B":
172
- processor = processor_m
173
- model = model_m
174
- elif model_name == "ViGaL-7B":
175
- processor = processor_x
176
- model = model_x
177
- elif model_name == "Visionary-R1-3B":
178
- processor = processor_o
179
- model = model_o
180
- elif model_name == "WR30a-Deep-7B-0711":
181
- processor = processor_t
182
- model = model_t
183
- elif model_name == "MonkeyOCR-pro-1.2B":
184
- processor = processor_w
185
- model = model_w
186
  else:
187
- yield "Invalid model selected.", "Invalid model selected."
188
- return
189
-
190
- if video_path is None:
191
- yield "Please upload a video.", "Please upload a video."
192
- return
193
-
194
- frames = downsample_video(video_path)
195
- messages = [{
196
- "role": "system",
197
- "content": [{"type": "text", "text": "You are a helpful assistant."}]
198
- }, {
199
- "role": "user",
200
- "content": [{"type": "text", "text": text}]
201
- }]
202
- for frame in frames:
203
- image, timestamp = frame
204
- messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
205
- messages[1]["content"].append({"type": "image", "image": image})
206
- inputs = processor.apply_chat_template(
207
- messages,
208
- tokenize=True,
209
- add_generation_prompt=True,
210
- return_dict=True,
211
- return_tensors="pt",
212
- truncation=False,
213
- max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
214
- streamer = TextIteratorStreamer(processor,
215
- skip_prompt=True,
216
- skip_special_tokens=True)
217
- generation_kwargs = {
218
- **inputs,
219
- "streamer": streamer,
220
- "max_new_tokens": max_new_tokens,
221
- "do_sample": True,
222
- "temperature": temperature,
223
- "top_p": top_p,
224
- "top_k": top_k,
225
- "repetition_penalty": repetition_penalty,
226
- }
227
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
228
- thread.start()
229
- buffer = ""
230
- for new_text in streamer:
231
- buffer += new_text
232
- buffer = buffer.replace("<|im_end|>", "")
233
- time.sleep(0.01)
234
- yield buffer, buffer
235
-
236
- # Define examples for image and video inference
237
- image_examples = [
238
- ["Extract the content.", "images/7.png"],
239
- ["Solve the problem to find the value.", "images/1.jpg"],
240
- ["Explain the scene.", "images/6.JPG"],
241
- ["Solve the problem step by step.", "images/2.jpg"],
242
- ["Find the value of 'X'.", "images/3.jpg"],
243
- ["Simplify the expression.", "images/4.jpg"],
244
- ["Solve for the value.", "images/5.png"]
245
- ]
246
-
247
- video_examples = [
248
- ["Explain the video in detail.", "videos/1.mp4"],
249
- ["Explain the video in detail.", "videos/2.mp4"]
250
- ]
251
 
252
- # Updated CSS with the new submit button theme
253
- css = """
254
- .submit-btn {
255
- --clr-font-main: hsla(0 0% 20% / 100);
256
- --btn-bg-1: hsla(194 100% 69% / 1);
257
- --btn-bg-2: hsla(217 100% 56% / 1);
258
- --btn-bg-color: hsla(360 100% 100% / 1);
259
- --radii: 0.5em;
260
- cursor: pointer;
261
- padding: 0.9em 1.4em;
262
- min-width: 120px;
263
- min-height: 44px;
264
- font-size: var(--size, 1rem);
265
- font-weight: 500;
266
- transition: 0.8s;
267
- background-size: 280% auto;
268
- background-image: linear-gradient(
269
- 325deg,
270
- var(--btn-bg-2) 0%,
271
- var(--btn-bg-1) 55%,
272
- var(--btn-bg-2) 90%
273
- );
274
- border: none;
275
- border-radius: var(--radii);
276
- color: var(--btn-bg-color);
277
- box-shadow:
278
- 0px 0px 20px rgba(71, 184, 255, 0.5),
279
- 0px 5px 5px -1px rgba(58, 125, 233, 0.25),
280
- inset 4px 4px 8px rgba(175, 230, 255, 0.5),
281
- inset -4px -4px 8px rgba(19, 95, 216, 0.35);
282
- }
283
- .submit-btn:hover {
284
- background-position: right top;
285
- }
286
- .submit-btn:is(:focus, :focus-visible, :active) {
287
- outline: none;
288
- box-shadow:
289
- 0 0 0 3px var(--btn-bg-color),
290
- 0 0 0 6px var(--btn-bg-2);
291
- }
292
- @media (prefers-reduced-motion: reduce) {
293
- .submit-btn {
294
- transition: linear;
295
- }
296
- }
297
- .canvas-output {
298
- border: 2px solid #4682B4;
299
- border-radius: 10px;
300
- padding: 20px;
301
- }
302
- """
303
 
304
- # Create the Gradio Interface
305
- with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
306
- gr.Markdown(
307
- "# **[Multimodal VLMs [OCR | VQA]](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**"
 
 
308
  )
309
- with gr.Row():
310
- with gr.Column():
311
- with gr.Tabs():
312
- with gr.TabItem("Image Inference"):
313
- image_query = gr.Textbox(
314
- label="Query Input",
315
- placeholder="Enter your query here...")
316
- image_upload = gr.Image(type="pil", label="Image")
317
- image_submit = gr.Button("Submit",
318
- elem_classes="submit-btn")
319
- gr.Examples(examples=image_examples,
320
- inputs=[image_query, image_upload])
321
- with gr.TabItem("Video Inference"):
322
- video_query = gr.Textbox(
323
- label="Query Input",
324
- placeholder="Enter your query here...")
325
- video_upload = gr.Video(label="Video")
326
- video_submit = gr.Button("Submit",
327
- elem_classes="submit-btn")
328
- gr.Examples(examples=video_examples,
329
- inputs=[video_query, video_upload])
330
-
331
- with gr.Accordion("Advanced options", open=False):
332
- max_new_tokens = gr.Slider(label="Max new tokens",
333
- minimum=1,
334
- maximum=MAX_MAX_NEW_TOKENS,
335
- step=1,
336
- value=DEFAULT_MAX_NEW_TOKENS)
337
- temperature = gr.Slider(label="Temperature",
338
- minimum=0.1,
339
- maximum=4.0,
340
- step=0.1,
341
- value=0.6)
342
- top_p = gr.Slider(label="Top-p (nucleus sampling)",
343
- minimum=0.05,
344
- maximum=1.0,
345
- step=0.05,
346
- value=0.9)
347
- top_k = gr.Slider(label="Top-k",
348
- minimum=1,
349
- maximum=1000,
350
- step=1,
351
- value=50)
352
- repetition_penalty = gr.Slider(label="Repetition penalty",
353
- minimum=1.0,
354
- maximum=2.0,
355
- step=0.05,
356
- value=1.2)
357
-
358
- with gr.Column():
359
- with gr.Column(elem_classes="canvas-output"):
360
- gr.Markdown("## Output")
361
- output = gr.Textbox(label="Raw Output Stream",
362
- interactive=False,
363
- lines=2, show_copy_button=True)
364
- with gr.Accordion("(Result.md)", open=False):
365
- markdown_output = gr.Markdown(
366
- label="markup.md")
367
- #download_btn = gr.Button("Download Result.md")
368
 
369
- model_choice = gr.Radio(choices=[
370
- "Vision-Matters-7B", "WR30a-Deep-7B-0711",
371
- "ViGaL-7B", "MonkeyOCR-pro-1.2B", "Visionary-R1-3B"
372
- ],
373
- label="Select Model",
374
- value="Vision-Matters-7B")
375
-
376
- gr.Markdown("**Model Info 💻** | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Multimodal-VLMs-5x/discussions)")
377
- gr.Markdown("> [WR30a-Deep-7B-0711](https://huggingface.co/prithivMLmods/WR30a-Deep-7B-0711): wr30a-deep-7b-0711 model is a fine-tuned version of qwen2.5-vl-7b-instruct, optimized for image captioning, visual analysis, and image reasoning. Built on top of the qwen2.5-vl architecture, this experimental model enhances visual comprehension capabilities with focused training on 1,500k image pairs for superior image understanding.")
378
- gr.Markdown("> [MonkeyOCR-pro-1.2B](https://huggingface.co/echo840/MonkeyOCR-pro-1.2B): MonkeyOCR adopts a structure-recognition-relation (SRR) triplet paradigm, which simplifies the multi-tool pipeline of modular approaches while avoiding the inefficiency of using large multimodal models for full-page document processing.")
379
- gr.Markdown("> [Vision Matters 7B](https://huggingface.co/Yuting6/Vision-Matters-7B): vision-matters is a simple visual perturbation framework that can be easily integrated into existing post-training pipelines including sft, dpo, and grpo. our findings highlight the critical role of visual perturbation: better reasoning begins with better seeing.")
380
- gr.Markdown("> [ViGaL 7B](https://huggingface.co/yunfeixie/ViGaL-7B): vigal-7b shows that training a 7b mllm on simple games like snake using reinforcement learning boosts performance on benchmarks like mathvista and mmmu without needing worked solutions or diagrams indicating transferable reasoning skills.")
381
- gr.Markdown("> [Visionary-R1](https://huggingface.co/maifoundations/Visionary-R1): visionary-r1 is a novel framework for training visual language models (vlms) to perform robust visual reasoning using reinforcement learning (rl). unlike traditional approaches that rely heavily on (sft) or (cot) annotations, visionary-r1 leverages only visual question-answer pairs and rl, making the process more scalable and accessible.")
382
- gr.Markdown(">⚠️note: all the models in space are not guaranteed to perform well in video inference use cases.")
383
 
384
- # Define the submit button actions
385
- image_submit.click(fn=generate_image,
386
- inputs=[
387
- model_choice, image_query, image_upload,
388
- max_new_tokens, temperature, top_p, top_k,
389
- repetition_penalty
390
- ],
391
- outputs=[output, markdown_output])
392
- video_submit.click(fn=generate_video,
393
- inputs=[
394
- model_choice, video_query, video_upload,
395
- max_new_tokens, temperature, top_p, top_k,
396
- repetition_penalty
397
- ],
398
- outputs=[output, markdown_output])
399
 
400
- if __name__ == "__main__":
401
- demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)
1
  import gradio as gr
2
+ import numpy as np
3
  import spaces
4
  import torch
5
+ import random
6
+ import json
7
+ import os
8
  from PIL import Image
9
+ from diffusers import FluxKontextPipeline
10
+ from diffusers.utils import load_image
11
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, list_repo_files
12
+ from safetensors.torch import load_file
13
+ import requests
14
+ import re
 
 
15
 
16
+ # Load Kontext model
17
+ MAX_SEED = np.iinfo(np.int32).max
 
 
18
 
19
+ pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
20
 
21
+ # Load LoRA data
22
+ flux_loras_raw = [
23
+ {
24
+ "image": "https://huggingface.co/prithivMLmods/FLUX.1-Kontext-Cinematic-Relighting/resolve/main/images/1.png",
25
+ "title": "Kontext Cinematic Relighting",
26
+ "repo": "prithivMLmods/FLUX.1-Kontext-Cinematic-Relighting",
27
+ "trigger_word": "Cinematic Relighting, Relight this portrait with warm, cinematic indoor lighting. Add soft amber highlights and gentle shadows to the face mimicking golden-hour light through a cozy room. Maintain natural skin texture and soft facial shadows, while enhancing eye catchlights for a vivid, lifelike look. Adjust white balance to a warmer tone, and slightly boost exposure to soften the darker midtones. Preserve the subject's pose and expression, and enhance the depth with gentle background bokeh and subtle filmic glow.",
28
+ "weights": "FLUX.1-Kontext-Cinematic-Relighting.safetensors"
29
+ },
30
+ ]
31
+ print(f"Loaded {len(flux_loras_raw)} LoRAs")
32
+ # Global variables for LoRA management
33
+ current_lora = None
34
+ lora_cache = {}
35
 
36
+ def load_lora_weights(repo_id, weights_filename):
37
+ """Load LoRA weights from HuggingFace"""
38
+ try:
39
+ # First try with the specified filename
40
+ try:
41
+ lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
42
+ if repo_id not in lora_cache:
43
+ lora_cache[repo_id] = lora_path
44
+ return lora_path
45
+ except Exception as e:
46
+ print(f"Failed to load {weights_filename}, trying to find alternative LoRA files...")
47
+
48
+ # If the specified file doesn't exist, try to find any .safetensors file
49
+ from huggingface_hub import list_repo_files
50
+ try:
51
+ files = list_repo_files(repo_id)
52
+ safetensors_files = [f for f in files if f.endswith(('.safetensors', '.bin')) and 'lora' in f.lower()]
53
+
54
+ if not safetensors_files:
55
+ # Try without 'lora' in filename
56
+ safetensors_files = [f for f in files if f.endswith('.safetensors')]
57
+
58
+ if safetensors_files:
59
+ # Try the first available file
60
+ for file in safetensors_files:
61
+ try:
62
+ print(f"Trying alternative file: {file}")
63
+ lora_path = hf_hub_download(repo_id=repo_id, filename=file)
64
+ if repo_id not in lora_cache:
65
+ lora_cache[repo_id] = lora_path
66
+ print(f"Successfully loaded alternative LoRA file: {file}")
67
+ return lora_path
68
+ except:
69
+ continue
70
+
71
+ print(f"No suitable LoRA files found in {repo_id}")
72
+ return None
73
+
74
+ except Exception as list_error:
75
+ print(f"Error listing files in repo {repo_id}: {list_error}")
76
+ return None
77
+
78
+ except Exception as e:
79
+ print(f"Error loading LoRA from {repo_id}: {e}")
80
+ return None
81
 
82
+ def update_selection(selected_state: gr.SelectData, flux_loras):
83
+ """Update UI when a LoRA is selected"""
84
+ if selected_state.index >= len(flux_loras):
85
+ return "### No LoRA selected", gr.update(), None
86
+
87
+ lora = flux_loras[selected_state.index]
88
+ lora_title = lora["title"]
89
+ lora_repo = lora["repo"]
90
+ trigger_word = lora["trigger_word"]
91
+
92
+ # Create a more informative selected text
93
+ updated_text = f"### 🎨 Selected Style: {lora_title}"
94
+ new_placeholder = f"Describe additional details, e.g., 'wearing a red hat' or 'smiling'"
95
+
96
+ return updated_text, gr.update(placeholder=new_placeholder), selected_state.index
97
 
98
+ def get_huggingface_lora(link):
99
+ """Download LoRA from HuggingFace link"""
100
+ split_link = link.split("/")
101
+ if len(split_link) == 2:
102
+ try:
103
+ model_card = ModelCard.load(link)
104
+ trigger_word = model_card.data.get("instance_prompt", "")
105
+
106
+ # Try to find the correct safetensors file
107
+ files = list_repo_files(link)
108
+ safetensors_files = [f for f in files if f.endswith('.safetensors')]
109
+
110
+ # Prioritize files with 'lora' in the name
111
+ lora_files = [f for f in safetensors_files if 'lora' in f.lower()]
112
+ if lora_files:
113
+ safetensors_file = lora_files[0]
114
+ elif safetensors_files:
115
+ safetensors_file = safetensors_files[0]
116
+ else:
117
+ # Try .bin files as fallback
118
+ bin_files = [f for f in files if f.endswith('.bin') and 'lora' in f.lower()]
119
+ if bin_files:
120
+ safetensors_file = bin_files[0]
121
+ else:
122
+ safetensors_file = "pytorch_lora_weights.safetensors" # Default fallback
123
+
124
+ print(f"Found LoRA file: {safetensors_file} in {link}")
125
+ return split_link[1], safetensors_file, trigger_word
126
+
127
+ except Exception as e:
128
+ print(f"Error in get_huggingface_lora: {e}")
129
+ # Try basic detection
130
+ try:
131
+ files = list_repo_files(link)
132
+ safetensors_file = next((f for f in files if f.endswith('.safetensors')), "pytorch_lora_weights.safetensors")
133
+ return split_link[1], safetensors_file, ""
134
+ except:
135
+ raise Exception(f"Error loading LoRA: {e}")
136
+ else:
137
+ raise Exception("Invalid HuggingFace repository format")
138
 
139
+ def load_custom_lora(link):
140
+ """Load custom LoRA from user input"""
141
+ if not link:
142
+ return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### 🎨 Select an art style from the gallery", None
143
+
144
+ try:
145
+ repo_name, weights_file, trigger_word = get_huggingface_lora(link)
146
+
147
+ card = f'''
148
+ <div class="custom_lora_card">
149
+ <div style="display: flex; align-items: center; margin-bottom: 12px;">
150
+ <span style="font-size: 18px; margin-right: 8px;">✅</span>
151
+ <strong style="font-size: 16px;">Custom LoRA Loaded!</strong>
152
+ </div>
153
+ <div style="background: rgba(255, 255, 255, 0.8); padding: 12px; border-radius: 8px;">
154
+ <h4 style="margin: 0 0 8px 0; color: #333;">{repo_name}</h4>
155
+ <small style="color: #666;">{"Trigger: <code style='background: #f0f0f0; padding: 2px 6px; border-radius: 4px;'><b>"+trigger_word+"</b></code>" if trigger_word else "No trigger word found"}</small>
156
+ </div>
157
+ </div>
158
+ '''
159
+
160
+ custom_lora_data = {
161
+ "repo": link,
162
+ "weights": weights_file,
163
+ "trigger_word": trigger_word
164
+ }
165
+
166
+ return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"🎨 Custom Style: {repo_name}", None
167
+
168
+ except Exception as e:
169
+ return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### 🎨 Select an art style from the gallery", None
170
 
171
+ def remove_custom_lora():
172
+ """Remove custom LoRA"""
173
+ return "", gr.update(visible=False), gr.update(visible=False), None, None
174
 
175
+ def classify_gallery(flux_loras):
176
+ """Sort gallery by likes"""
177
+ try:
178
+ sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
179
+ gallery_items = []
180
+
181
+ for item in sorted_gallery:
182
+ if "image" in item and "title" in item:
183
+ image_path = item["image"]
184
+ title = item["title"]
185
+
186
+ # Simply use the path as-is for Gradio to handle
187
+ gallery_items.append((image_path, title))
188
+ print(f"Added to gallery: {image_path} - {title}")
189
+
190
+ print(f"Total gallery items: {len(gallery_items)}")
191
+ return gallery_items, sorted_gallery
192
+ except Exception as e:
193
+ print(f"Error in classify_gallery: {e}")
194
+ import traceback
195
+ traceback.print_exc()
196
+ return [], []
197
 
198
+ def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
199
+ """Wrapper function to handle state serialization"""
200
+ return infer_with_lora(input_image, prompt, selected_index, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, flux_loras, progress)
201
 
 
202
  @spaces.GPU
203
+ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
204
+ """Generate image with selected LoRA"""
205
+ global current_lora, pipe
206
+
207
+ # Check if input image is provided
208
+ if input_image is None:
209
+ gr.Warning("Please upload your portrait photo first! 📸")
210
+ return None, seed, gr.update(visible=False)
211
+
212
+ if randomize_seed:
213
+ seed = random.randint(0, MAX_SEED)
214
+
215
+ # Determine which LoRA to use
216
+ lora_to_use = None
217
+ if custom_lora:
218
+ lora_to_use = custom_lora
219
+ elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
220
+ lora_to_use = flux_loras[selected_index]
221
+ # Load LoRA if needed
222
+ if lora_to_use and lora_to_use != current_lora:
223
+ try:
224
+ # Unload current LoRA
225
+ if current_lora:
226
+ pipe.unload_lora_weights()
227
+ print(f"Unloaded previous LoRA")
228
+
229
+ # Load new LoRA
230
+ repo_id = lora_to_use.get("repo", "unknown")
231
+ weights_file = lora_to_use.get("weights", "pytorch_lora_weights.safetensors")
232
+ print(f"Loading LoRA: {repo_id} with weights: {weights_file}")
233
+
234
+ lora_path = load_lora_weights(repo_id, weights_file)
235
+ if lora_path:
236
+ pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
237
+ pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
238
+ print(f"Successfully loaded: {lora_path} with scale {lora_scale}")
239
+ current_lora = lora_to_use
240
+ else:
241
+ print(f"Failed to load LoRA from {repo_id}")
242
+ gr.Warning(f"Failed to load {lora_to_use.get('title', 'style')}. Please try a different art style.")
243
+ return None, seed, gr.update(visible=False)
244
+
245
+ except Exception as e:
246
+ print(f"Error loading LoRA: {e}")
247
+ # Continue without LoRA
248
  else:
249
+ if lora_to_use:
250
+ print(f"Using already loaded LoRA: {lora_to_use.get('repo', 'unknown')}")
251
+
252
+ try:
253
+ # Convert image to RGB
254
+ input_image = input_image.convert("RGB")
255
+ except Exception as e:
256
+ print(f"Error processing image: {e}")
257
+ gr.Warning("Error processing the uploaded image. Please try a different photo. 📸")
258
+ return None, seed, gr.update(visible=False)
259
+
260
+ # Check if LoRA is selected
261
+ if lora_to_use is None:
262
+ gr.Warning("Please select an art style from the gallery first! 🎨")
263
+ return None, seed, gr.update(visible=False)
264
+
265
+ # Add trigger word to prompt
266
+ trigger_word = lora_to_use.get("trigger_word", "")
267
+
268
+ # Special handling for different art styles
269
+ if trigger_word == "ghibli":
270
+ prompt = f"Create a Studio Ghibli anime style portrait of the person in the photo, {prompt}. Maintain the facial identity while transforming into whimsical anime art style."
271
+ elif trigger_word == "homer":
272
+ prompt = f"Paint the person in Winslow Homer's American realist style, {prompt}. Keep facial features while applying watercolor and marine art techniques."
273
+ elif trigger_word == "gogh":
274
+ prompt = f"Transform the portrait into Van Gogh's post-impressionist style with swirling brushstrokes, {prompt}. Maintain facial identity with expressive colors."
275
+ elif trigger_word == "Cezanne":
276
+ prompt = f"Render the person in Paul Cézanne's geometric post-impressionist style, {prompt}. Keep facial structure while applying structured brushwork."
277
+ elif trigger_word == "Renoir":
278
+ prompt = f"Paint the portrait in Pierre-Auguste Renoir's impressionist style with soft light, {prompt}. Maintain identity with luminous skin tones."
279
+ elif trigger_word == "claude monet":
280
+ prompt = f"Create an impressionist portrait in Claude Monet's style with visible brushstrokes, {prompt}. Keep facial features while using light and color."
281
+ elif trigger_word == "fantasy":
282
+ prompt = f"Transform into an epic fantasy character portrait, {prompt}. Maintain facial identity while adding magical and fantastical elements."
283
+ elif trigger_word == ", How2Draw":
284
+ prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
285
+ elif trigger_word == ", video game screenshot in the style of THSMS":
286
+ prompt = f"create a video game screenshot in the style of THSMS with the person from the photo, {prompt}. maintain the facial identity of the person and general features"
287
+ else:
288
+ prompt = f"convert the style of this portrait photo to {trigger_word} while maintaining the identity of the person. {prompt}. Make sure to maintain the person's facial identity and features, while still changing the overall style to {trigger_word}."
289
+
290
+ try:
291
+ image = pipe(
292
+ image=input_image,
293
+ prompt=prompt,
294
+ guidance_scale=guidance_scale,
295
+ generator=torch.Generator().manual_seed(seed),
296
+ ).images[0]
297
+
298
+ return image, seed, gr.update(visible=True)
299
+
300
+ except Exception as e:
301
+ print(f"Error during inference: {e}")
302
+ return None, seed, gr.update(visible=False)
303
 
304
+ # CSS styling with beautiful gradient pastel design
305
+ css = '''
306
+ #gen_btn{height: 100%}
307
+ #gen_column{align-self: stretch}
308
+ #title{text-align: center}
309
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
310
+ #title img{width: 100px; margin-right: 0.5em}
311
+ #gallery .grid-wrap{height: 10vh}
312
+ #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
313
+ .card_internal{display: flex;height: 100px;margin-top: .5em}
314
+ .card_internal img{margin-right: 1em}
315
+ .styler{--form-gap-width: 0px !important}
316
+ #progress{height:30px}
317
+ #progress .generating{display:none}
318
+ .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
319
+ .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
320
+ '''
321
 
322
+ # Create Gradio interface
323
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
324
+ gr_flux_loras = gr.State(value=flux_loras_raw)
325
+
326
+ title = gr.HTML(
327
+ """<h1>Flux Kontext DLC 🎈</h1>""",
328
  )
329
+
330
+ selected_state = gr.State(value=None)
331
+ custom_loaded_lora = gr.State(value=None)
332
+
333
+ with gr.Row(elem_id="main_app"):
334
+ with gr.Column(scale=4, elem_id="box_column"):
335
+ with gr.Group(elem_id="gallery_box"):
336
+ input_image = gr.Image(label="Upload an image for editing", type="pil", height=260)
337
+
338
+ gallery = gr.Gallery(
339
+ label="Choose the Flux Kontext LoRA",
340
+ allow_preview=False,
341
+ columns=3,
342
+ elem_id="gallery",
343
+ show_share_button=False,
344
+ height=400
345
+ )
346
+
347
+ custom_model = gr.Textbox(
348
+ label="🔗 Or use a custom LoRA from HuggingFace",
349
+ placeholder="e.g., username/lora-name",
350
+ visible=True
351
+ )
352
+ custom_model_card = gr.HTML(visible=False)
353
+ custom_model_button = gr.Button("Remove custom LoRA", visible=False)
354
+
355
+ with gr.Column(scale=5):
356
+ with gr.Row():
357
+ prompt = gr.Textbox(
358
+ label="Additional Details (optional)",
359
+ show_label=False,
360
+ lines=1,
361
+ max_lines=1,
362
+ placeholder="Describe additional details, e.g., 'wearing a red hat' or 'smiling'",
363
+ elem_id="prompt"
364
+ )
365
+ run_button = gr.Button("Edit Image", elem_id="run_button")
366
+
367
+ result = gr.Image(label="Your Kontext Edited Image", interactive=False)
368
+ reuse_button = gr.Button("Reuse this image", visible=False)
369
+
370
+ with gr.Accordion("Advanced Settings", open=False):
371
+ lora_scale = gr.Slider(
372
+ label="Style Strength",
373
+ minimum=0,
374
+ maximum=2,
375
+ step=0.1,
376
+ value=1.0,
377
+ info="How strongly to apply the art style (1.0 = balanced)"
378
+ )
379
+ seed = gr.Slider(
380
+ label="Random Seed",
381
+ minimum=0,
382
+ maximum=MAX_SEED,
383
+ step=1,
384
+ value=0,
385
+ info="Set to 0 for random results"
386
+ )
387
+ randomize_seed = gr.Checkbox(label="Randomize seed for each generation", value=True)
388
+ guidance_scale = gr.Slider(
389
+ label="Image Guidance",
390
+ minimum=1,
391
+ maximum=10,
392
+ step=0.1,
393
+ value=2.5,
394
+ info="How closely to follow the input image (lower = more creative)"
395
+ )
396
+
397
+ prompt_title = gr.Markdown(
398
+ value="### Select an art style from the gallery",
399
+ visible=True,
400
+ elem_id="selected_lora",
401
+ )
402
 
403
+ # Event handlers
404
+ custom_model.input(
405
+ fn=load_custom_lora,
406
+ inputs=[custom_model],
407
+ outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title, selected_state],
408
+ )
409
+
410
+ custom_model_button.click(
411
+ fn=remove_custom_lora,
412
+ outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora, selected_state]
413
+ )
414
+
415
+ gallery.select(
416
+ fn=update_selection,
417
+ inputs=[gr_flux_loras],
418
+ outputs=[prompt_title, prompt, selected_state],
419
+ show_progress=False
420
+ )
421
+
422
+ gr.on(
423
+ triggers=[run_button.click, prompt.submit],
424
+ fn=infer_with_lora_wrapper,
425
+ inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
426
+ outputs=[result, seed, reuse_button]
427
+ )
428
+
429
+ reuse_button.click(
430
+ fn=lambda image: image,
431
+ inputs=[result],
432
+ outputs=[input_image]
433
+ )
434
 
435
+ demo.load(
436
+ fn=classify_gallery,
437
+ inputs=[gr_flux_loras],
438
+ outputs=[gallery, gr_flux_loras]
439
+ )
440
 
441
+ demo.queue(default_concurrency_limit=None)
442
+ demo.launch()
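
For reference, the updated app.py centers on the FLUX.1 Kontext image-editing flow shown above: a base FluxKontextPipeline plus an optional Kontext LoRA. Below is a minimal standalone sketch of that flow, assuming the same base checkpoint and the Cinematic Relighting LoRA listed in flux_loras_raw; the input file name "portrait.jpg" is hypothetical.

import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download

# Base FLUX.1 Kontext editing pipeline (same checkpoint as the Space).
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Attach one Kontext LoRA (repo and filename taken from flux_loras_raw above).
lora_path = hf_hub_download(
    repo_id="prithivMLmods/FLUX.1-Kontext-Cinematic-Relighting",
    filename="FLUX.1-Kontext-Cinematic-Relighting.safetensors",
)
pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
pipe.set_adapters(["selected_lora"], adapter_weights=[1.0])

# Edit an input image; the prompt should include the LoRA's trigger phrase.
source = load_image("portrait.jpg").convert("RGB")  # hypothetical local file
edited = pipe(
    image=source,
    prompt="Cinematic Relighting, relight this portrait with warm, cinematic indoor lighting.",
    guidance_scale=2.5,
    generator=torch.Generator().manual_seed(42),
).images[0]
edited.save("edited.png")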