prithivMLmods committed on
Commit 0f2e032 · verified · 1 Parent(s): e1299f3

Update app.py

Files changed (1): app.py (+453, -339)
app.py CHANGED
@@ -1,360 +1,474 @@
- import os
- import random
- import uuid
- import json
- import time
- import asyncio
- from threading import Thread
-
  import gradio as gr
  import spaces
  import torch
- import numpy as np
  from PIL import Image
- import cv2
-
- from transformers import (
-     Qwen2_5_VLForConditionalGeneration,
-     AutoModelForCausalLM,
-     AutoProcessor,
-     TextIteratorStreamer,
- )
- from transformers.image_utils import load_image
-
- # Constants for text generation
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- # Load Camel-Doc-OCR-080125
- MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-080125"
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_M, trust_remote_code=True,
-     torch_dtype=torch.float16).to(device).eval()
-
- # Load OCRFlux-3B
- MODEL_ID_X = "ChatDOC/OCRFlux-3B"
- processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
- model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_X, trust_remote_code=True,
-     torch_dtype=torch.float16).to(device).eval()
-
- # Load Behemoth-3B-070225
- MODEL_ID_T = "prithivMLmods/Behemoth-3B-070225-post0.1"
- processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
- model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_T, trust_remote_code=True,
-     torch_dtype=torch.float16).to(device).eval()
-
- # Load MonkeyOCR-pro-1.2B
- MODEL_ID_O = "echo840/MonkeyOCR-pro-1.2B"
- SUBFOLDER = "Recognition"
- processor_o = AutoProcessor.from_pretrained(MODEL_ID_O, trust_remote_code=True, subfolder=SUBFOLDER)
- model_o = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_O, trust_remote_code=True, subfolder=SUBFOLDER,
-     torch_dtype=torch.float16).to(device).eval()
-
- # Load ViGoRL-MCTS-SFT-7b-Spatial
- MODEL_ID_A = "gsarch/ViGoRL-MCTS-SFT-7b-Spatial"
- processor_a = AutoProcessor.from_pretrained(MODEL_ID_A, trust_remote_code=True)
- model_a = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_A, trust_remote_code=True,
-     torch_dtype=torch.float16).to(device).eval()
-
- # Function to downsample video frames
- def downsample_video(video_path):
-     """
-     Downsamples the video to evenly spaced frames.
-     Each frame is returned as a PIL image along with its timestamp.
-     """
-     vidcap = cv2.VideoCapture(video_path)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     frames = []
-     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-     vidcap.release()
-     return frames
-
- # Function to generate text responses based on image input
- @spaces.GPU
- def generate_image(model_name: str,
-                    text: str,
-                    image: Image.Image,
-                    max_new_tokens: int = 1024,
-                    temperature: float = 0.6,
-                    top_p: float = 0.9,
-                    top_k: int = 50,
-                    repetition_penalty: float = 1.2):
-     """
-     Generates responses using the selected model for image input.
-     """
-     if model_name == "Camel-Doc-OCR-080125(v2)":
-         processor = processor_m
-         model = model_m
-     elif model_name == "OCRFlux-3B":
-         processor = processor_x
-         model = model_x
-     elif model_name == "Behemoth-3B-070225":
-         processor = processor_o
-         model = model_o
-     elif model_name == "MonkeyOCR-pro-1.2B":
-         processor = processor_t
-         model = model_t
-     elif model_name == "ViGoRL-MCTS-SFT-7B":
-         processor = processor_a
-         model = model_a
-     else:
-         yield "Invalid model selected.", "Invalid model selected."
-         return
-
-     if image is None:
-         yield "Please upload an image.", "Please upload an image."
-         return
-
-     messages = [{
-         "role": "user",
-         "content": [
-             {"type": "image", "image": image},
-             {"type": "text", "text": text},
-         ]
-     }]
-     prompt_full = processor.apply_chat_template(messages,
-                                                 tokenize=False,
-                                                 add_generation_prompt=True)
-     inputs = processor(text=[prompt_full],
-                        images=[image],
-                        return_tensors="pt",
-                        padding=True,
-                        truncation=False,
-                        max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
-     streamer = TextIteratorStreamer(processor,
-                                     skip_prompt=True,
-                                     skip_special_tokens=True)
-     generation_kwargs = {
-         **inputs, "streamer": streamer,
-         "max_new_tokens": max_new_tokens
-     }
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         time.sleep(0.01)
-         yield buffer, buffer
-
- # Function to generate text responses based on video input
  @spaces.GPU
- def generate_video(model_name: str,
-                    text: str,
-                    video_path: str,
-                    max_new_tokens: int = 1024,
-                    temperature: float = 0.6,
-                    top_p: float = 0.9,
-                    top_k: int = 50,
-                    repetition_penalty: float = 1.2):
-     """
-     Generates responses using the selected model for video input.
-     """
-     if model_name == "Camel-Doc-OCR-080125(v2)":
-         processor = processor_m
-         model = model_m
-     elif model_name == "OCRFlux-3B":
-         processor = processor_x
-         model = model_x
-     elif model_name == "Behemoth-3B-070225":
-         processor = processor_o
-         model = model_o
-     elif model_name == "MonkeyOCR-pro-1.2B":
-         processor = processor_t
-         model = model_t
-     elif model_name == "ViGoRL-MCTS-SFT-7B":
-         processor = processor_a
-         model = model_a
      else:
-         yield "Invalid model selected.", "Invalid model selected."
-         return
-
-     if video_path is None:
-         yield "Please upload a video.", "Please upload a video."
-         return
-
-     frames = downsample_video(video_path)
-     messages = [{
-         "role": "system",
-         "content": [{"type": "text", "text": "You are a helpful assistant."}]
-     }, {
-         "role": "user",
-         "content": [{"type": "text", "text": text}]
-     }]
-     for frame in frames:
-         image, timestamp = frame
-         messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-         messages[1]["content"].append({"type": "image", "image": image})
-     inputs = processor.apply_chat_template(
-         messages,
-         tokenize=True,
-         add_generation_prompt=True,
-         return_dict=True,
-         return_tensors="pt",
-         truncation=False,
-         max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
-     streamer = TextIteratorStreamer(processor,
-                                     skip_prompt=True,
-                                     skip_special_tokens=True)
-     generation_kwargs = {
-         **inputs,
-         "streamer": streamer,
-         "max_new_tokens": max_new_tokens,
-         "do_sample": True,
-         "temperature": temperature,
-         "top_p": top_p,
-         "top_k": top_k,
-         "repetition_penalty": repetition_penalty,
-     }
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         buffer = buffer.replace("<|im_end|>", "")
-         time.sleep(0.01)
-         yield buffer, buffer
-
- # Define examples for image and video inference
- image_examples = [
-     ["Explain the essence of the image.", "assets/images/B.jpg"],
-     ["Extract the content.", "assets/images/1.png"],
-     ["Describe the safety of the action shown in the image.", "assets/images/C.jpg"],
-     ["Caption the image.", "assets/images/A.jpg"],
-     ["Make this into a table for the README.md file.", "assets/images/2.jpg"],
-     ["Extract the table content from the image.", "assets/images/3.png"],
-     ["Perform OCR on the image.", "assets/images/4.jpg"]
- ]
-
- video_examples = [
-     ["Explain the video in detail.", "assets/videos/a.mp4"],
-     ["Explain the video in detail.", "assets/videos/b.mp4"]
- ]
-
- # css
- css = """
- .submit-btn {
-     background-color: #2980b9 !important;
-     color: white !important;
- }
- .submit-btn:hover {
-     background-color: #3498db !important;
- }
- .canvas-output {
-     border: 2px solid #4682B4;
-     border-radius: 10px;
-     padding: 20px;
- }
- """
-
- # Create the Gradio Interface
- with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-     gr.Markdown(
-         "# **[Multimodal OCR Outpost](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**"
      )
-     with gr.Row():
-         with gr.Column():
-             with gr.Tabs():
-                 with gr.TabItem("Image Inference"):
-                     image_query = gr.Textbox(
-                         label="Query Input",
-                         placeholder="Enter your query here...")
-                     image_upload = gr.Image(type="pil", label="Image")
-                     image_submit = gr.Button("Submit",
-                                              elem_classes="submit-btn")
-                     gr.Examples(examples=image_examples,
-                                 inputs=[image_query, image_upload])
-                 with gr.TabItem("Video Inference"):
-                     video_query = gr.Textbox(
-                         label="Query Input",
-                         placeholder="Enter your query here...")
-                     video_upload = gr.Video(label="Video")
-                     video_submit = gr.Button("Submit",
-                                              elem_classes="submit-btn")
-                     gr.Examples(examples=video_examples,
-                                 inputs=[video_query, video_upload])
-
-             with gr.Accordion("Advanced options", open=False):
-                 max_new_tokens = gr.Slider(label="Max new tokens",
-                                            minimum=1,
-                                            maximum=MAX_MAX_NEW_TOKENS,
-                                            step=1,
-                                            value=DEFAULT_MAX_NEW_TOKENS)
-                 temperature = gr.Slider(label="Temperature",
-                                         minimum=0.1,
-                                         maximum=4.0,
-                                         step=0.1,
-                                         value=0.6)
-                 top_p = gr.Slider(label="Top-p (nucleus sampling)",
-                                   minimum=0.05,
-                                   maximum=1.0,
-                                   step=0.05,
-                                   value=0.9)
-                 top_k = gr.Slider(label="Top-k",
-                                   minimum=1,
-                                   maximum=1000,
-                                   step=1,
-                                   value=50)
-                 repetition_penalty = gr.Slider(label="Repetition penalty",
-                                                minimum=1.0,
-                                                maximum=2.0,
-                                                step=0.05,
-                                                value=1.2)
-
-         with gr.Column():
-             with gr.Column(elem_classes="canvas-output"):
-                 gr.Markdown("## Output")
-                 output = gr.Textbox(label="Raw Output Stream",
-                                     interactive=False,
-                                     lines=2, show_copy_button=True)
-                 with gr.Accordion("(Result.md)", open=False):
-                     markdown_output = gr.Markdown(
-                         label="markup.md")
-
-             model_choice = gr.Radio(choices=[
-                 "Camel-Doc-OCR-080125(v2)", "OCRFlux-3B",
-                 "ViGoRL-MCTS-SFT-7B", "Behemoth-3B-070225",
-                 "MonkeyOCR-pro-1.2B"],
-                 label="Select Model",
-                 value="Camel-Doc-OCR-080125(v2)")
-
-             gr.Markdown("**Model Info 💻** | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR-Outpost/discussions)")
-             gr.Markdown("> Camel-Doc-OCR-080125 is a specialized vision-language model, fine-tuned from Qwen2.5-VL-7B-Instruct, and excels at document retrieval, content extraction, and analysis recognition for both structured and unstructured digital documents. OCRFlux-3B is a 3B-parameter vision-language model optimized for high-quality OCR on PDFs and images, excelling in converting documents to clean Markdown text and supporting features like cross-page table/paragraph merging.")
-             gr.Markdown("> Both ViGoRL-MCTS-SFT-3b-Spatial and 7b-Spatial are vision-language models that use multi-turn visually grounded reinforcement learning for precise spatial reasoning and visual grounding, with the 3b and 7b variants differing mainly in their architectural size for fine-grained visual tasks.")
-             gr.Markdown("> Behemoth-3B-070225-post0.1 is an advanced 3B parameter model tailored for extensive multimodal comprehension, document parsing, and possibly generalized OCR/vision-language tasks. MonkeyOCR-pro-1.2B is a lightweight OCR model focusing on high-accuracy text extraction from images and scanned documents, suitable for resource-constrained environments.")
-             gr.Markdown("> ⚠️ Note: Models in this space may not perform well on video inference tasks.")
-
-     # Define the submit button actions
-     image_submit.click(fn=generate_image,
-                        inputs=[
-                            model_choice, image_query, image_upload,
-                            max_new_tokens, temperature, top_p, top_k,
-                            repetition_penalty
-                        ],
-                        outputs=[output, markdown_output])
-     video_submit.click(fn=generate_video,
-                        inputs=[
-                            model_choice, video_query, video_upload,
-                            max_new_tokens, temperature, top_p, top_k,
-                            repetition_penalty
-                        ],
-                        outputs=[output, markdown_output])
-
- if __name__ == "__main__":
-     demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)
  import gradio as gr
+ import numpy as np
  import spaces
  import torch
+ import random
+ import json
+ import os
  from PIL import Image
+ from diffusers import FluxKontextPipeline
+ from diffusers.utils import load_image
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, list_repo_files
+ from safetensors.torch import load_file
+ import requests
+ import re
+
+ # Load Kontext model
+ MAX_SEED = np.iinfo(np.int32).max
+
+ pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
+
+ # Load LoRA data
+ flux_loras_raw = [
+     {
+         "image": "examples/1.png",
+         "title": "Studio Ghibli",
+         "repo": "openfree/flux-chatgpt-ghibli-lora",
+         "trigger_word": "ghibli",
+         "weights": "pytorch_lora_weights.safetensors",
+         "likes": 0
+     },
+     {
+         "image": "examples/2.png",
+         "title": "Winslow Homer",
+         "repo": "openfree/winslow-homer",
+         "trigger_word": "homer",
+         "weights": "pytorch_lora_weights.safetensors",
+         "likes": 0
+     },
+     {
+         "image": "examples/3.png",
+         "title": "Van Gogh",
+         "repo": "openfree/van-gogh",
+         "trigger_word": "gogh",
+         "weights": "pytorch_lora_weights.safetensors",
+         "likes": 0
+     },
+     {
+         "image": "examples/4.png",
+         "title": "Paul Cézanne",
+         "repo": "openfree/paul-cezanne",
+         "trigger_word": "Cezanne",
+         "weights": "pytorch_lora_weights.safetensors",
+         "likes": 0
+     },
+     {
+         "image": "examples/5.png",
+         "title": "Renoir",
+         "repo": "openfree/pierre-auguste-renoir",
+         "trigger_word": "Renoir",
+         "weights": "pytorch_lora_weights.safetensors",
+         "likes": 0
+     },
+     {
+         "image": "examples/6.png",
+         "title": "Claude Monet",
+         "repo": "openfree/claude-monet",
+         "trigger_word": "claude monet",
+         "weights": "pytorch_lora_weights.safetensors",
+         "likes": 0
+     },
+     {
+         "image": "examples/7.png",
+         "title": "Fantasy Art",
+         "repo": "openfree/myt-flux-fantasy",
+         "trigger_word": "fantasy",
+         "weights": "pytorch_lora_weights.safetensors",
+         "likes": 0
+     }
+ ]
+ print(f"Loaded {len(flux_loras_raw)} LoRAs")
+
+ # Global variables for LoRA management
+ current_lora = None
+ lora_cache = {}
+
+ def load_lora_weights(repo_id, weights_filename):
+     """Load LoRA weights from HuggingFace"""
+     try:
+         # First try with the specified filename
+         try:
+             lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
+             if repo_id not in lora_cache:
+                 lora_cache[repo_id] = lora_path
+             return lora_path
+         except Exception as e:
+             print(f"Failed to load {weights_filename}, trying to find alternative LoRA files...")
+
+             # If the specified file doesn't exist, try to find any .safetensors file
+             from huggingface_hub import list_repo_files
+             try:
+                 files = list_repo_files(repo_id)
+                 safetensors_files = [f for f in files if f.endswith(('.safetensors', '.bin')) and 'lora' in f.lower()]
+
+                 if not safetensors_files:
+                     # Try without 'lora' in filename
+                     safetensors_files = [f for f in files if f.endswith('.safetensors')]
+
+                 if safetensors_files:
+                     # Try the first available file
+                     for file in safetensors_files:
+                         try:
+                             print(f"Trying alternative file: {file}")
+                             lora_path = hf_hub_download(repo_id=repo_id, filename=file)
+                             if repo_id not in lora_cache:
+                                 lora_cache[repo_id] = lora_path
+                             print(f"Successfully loaded alternative LoRA file: {file}")
+                             return lora_path
+                         except:
+                             continue
+
+                 print(f"No suitable LoRA files found in {repo_id}")
+                 return None
+
+             except Exception as list_error:
+                 print(f"Error listing files in repo {repo_id}: {list_error}")
+                 return None
+
+     except Exception as e:
+         print(f"Error loading LoRA from {repo_id}: {e}")
+         return None
+
+ def update_selection(selected_state: gr.SelectData, flux_loras):
+     """Update UI when a LoRA is selected"""
+     if selected_state.index >= len(flux_loras):
+         return "### No LoRA selected", gr.update(), None
+
+     lora = flux_loras[selected_state.index]
+     lora_title = lora["title"]
+     lora_repo = lora["repo"]
+     trigger_word = lora["trigger_word"]
+
+     # Create a more informative selected text
+     updated_text = f"### 🎨 Selected Style: {lora_title}"
+     new_placeholder = f"Describe additional details, e.g., 'wearing a red hat' or 'smiling'"
+
+     return updated_text, gr.update(placeholder=new_placeholder), selected_state.index
+
+ def get_huggingface_lora(link):
+     """Download LoRA from HuggingFace link"""
+     split_link = link.split("/")
+     if len(split_link) == 2:
+         try:
+             model_card = ModelCard.load(link)
+             trigger_word = model_card.data.get("instance_prompt", "")
+
+             # Try to find the correct safetensors file
+             files = list_repo_files(link)
+             safetensors_files = [f for f in files if f.endswith('.safetensors')]
+
+             # Prioritize files with 'lora' in the name
+             lora_files = [f for f in safetensors_files if 'lora' in f.lower()]
+             if lora_files:
+                 safetensors_file = lora_files[0]
+             elif safetensors_files:
+                 safetensors_file = safetensors_files[0]
+             else:
+                 # Try .bin files as fallback
+                 bin_files = [f for f in files if f.endswith('.bin') and 'lora' in f.lower()]
+                 if bin_files:
+                     safetensors_file = bin_files[0]
+                 else:
+                     safetensors_file = "pytorch_lora_weights.safetensors"  # Default fallback
+
+             print(f"Found LoRA file: {safetensors_file} in {link}")
+             return split_link[1], safetensors_file, trigger_word
+
+         except Exception as e:
+             print(f"Error in get_huggingface_lora: {e}")
+             # Try basic detection
+             try:
+                 files = list_repo_files(link)
+                 safetensors_file = next((f for f in files if f.endswith('.safetensors')), "pytorch_lora_weights.safetensors")
+                 return split_link[1], safetensors_file, ""
+             except:
+                 raise Exception(f"Error loading LoRA: {e}")
+     else:
+         raise Exception("Invalid HuggingFace repository format")
+
+ def load_custom_lora(link):
+     """Load custom LoRA from user input"""
+     if not link:
+         return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### 🎨 Select an art style from the gallery", None
+
+     try:
+         repo_name, weights_file, trigger_word = get_huggingface_lora(link)
+
+         card = f'''
+         <div class="custom_lora_card">
+             <div style="display: flex; align-items: center; margin-bottom: 12px;">
+                 <span style="font-size: 18px; margin-right: 8px;">✅</span>
+                 <strong style="font-size: 16px;">Custom LoRA Loaded!</strong>
+             </div>
+             <div style="background: rgba(255, 255, 255, 0.8); padding: 12px; border-radius: 8px;">
+                 <h4 style="margin: 0 0 8px 0; color: #333;">{repo_name}</h4>
+                 <small style="color: #666;">{"Trigger: <code style='background: #f0f0f0; padding: 2px 6px; border-radius: 4px;'><b>"+trigger_word+"</b></code>" if trigger_word else "No trigger word found"}</small>
+             </div>
+         </div>
+         '''
+
+         custom_lora_data = {
+             "repo": link,
+             "weights": weights_file,
+             "trigger_word": trigger_word
+         }
+
+         return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"🎨 Custom Style: {repo_name}", None
+
+     except Exception as e:
+         return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### 🎨 Select an art style from the gallery", None
+
+ def remove_custom_lora():
+     """Remove custom LoRA"""
+     return "", gr.update(visible=False), gr.update(visible=False), None, None
+
+ def classify_gallery(flux_loras):
+     """Sort gallery by likes"""
+     try:
+         sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
+         gallery_items = []
+
+         for item in sorted_gallery:
+             if "image" in item and "title" in item:
+                 image_path = item["image"]
+                 title = item["title"]
+
+                 # Simply use the path as-is for Gradio to handle
+                 gallery_items.append((image_path, title))
+                 print(f"Added to gallery: {image_path} - {title}")
+
+         print(f"Total gallery items: {len(gallery_items)}")
+         return gallery_items, sorted_gallery
+     except Exception as e:
+         print(f"Error in classify_gallery: {e}")
+         import traceback
+         traceback.print_exc()
+         return [], []
+
+ def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
+     """Wrapper function to handle state serialization"""
+     return infer_with_lora(input_image, prompt, selected_index, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, flux_loras, progress)
+
  @spaces.GPU
+ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
+     """Generate image with selected LoRA"""
+     global current_lora, pipe
+
+     # Check if input image is provided
+     if input_image is None:
+         gr.Warning("Please upload your portrait photo first! 📸")
+         return None, seed, gr.update(visible=False)
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     # Determine which LoRA to use
+     lora_to_use = None
+     if custom_lora:
+         lora_to_use = custom_lora
+     elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
+         lora_to_use = flux_loras[selected_index]
+
+     # Load LoRA if needed
+     if lora_to_use and lora_to_use != current_lora:
+         try:
+             # Unload current LoRA
+             if current_lora:
+                 pipe.unload_lora_weights()
+                 print(f"Unloaded previous LoRA")
+
+             # Load new LoRA
+             repo_id = lora_to_use.get("repo", "unknown")
+             weights_file = lora_to_use.get("weights", "pytorch_lora_weights.safetensors")
+             print(f"Loading LoRA: {repo_id} with weights: {weights_file}")
+
+             lora_path = load_lora_weights(repo_id, weights_file)
+             if lora_path:
+                 pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
+                 pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
+                 print(f"Successfully loaded: {lora_path} with scale {lora_scale}")
+                 current_lora = lora_to_use
+             else:
+                 print(f"Failed to load LoRA from {repo_id}")
+                 gr.Warning(f"Failed to load {lora_to_use.get('title', 'style')}. Please try a different art style.")
+                 return None, seed, gr.update(visible=False)
+
+         except Exception as e:
+             print(f"Error loading LoRA: {e}")
+             # Continue without LoRA
      else:
+         if lora_to_use:
+             print(f"Using already loaded LoRA: {lora_to_use.get('repo', 'unknown')}")
+
+     try:
+         # Convert image to RGB
+         input_image = input_image.convert("RGB")
+     except Exception as e:
+         print(f"Error processing image: {e}")
+         gr.Warning("Error processing the uploaded image. Please try a different photo. 📸")
+         return None, seed, gr.update(visible=False)
+
+     # Check if LoRA is selected
+     if lora_to_use is None:
+         gr.Warning("Please select an art style from the gallery first! 🎨")
+         return None, seed, gr.update(visible=False)
+
+     # Add trigger word to prompt
+     trigger_word = lora_to_use.get("trigger_word", "")
+
+     # Special handling for different art styles
+     if trigger_word == "ghibli":
+         prompt = f"Create a Studio Ghibli anime style portrait of the person in the photo, {prompt}. Maintain the facial identity while transforming into whimsical anime art style."
+     elif trigger_word == "homer":
+         prompt = f"Paint the person in Winslow Homer's American realist style, {prompt}. Keep facial features while applying watercolor and marine art techniques."
+     elif trigger_word == "gogh":
+         prompt = f"Transform the portrait into Van Gogh's post-impressionist style with swirling brushstrokes, {prompt}. Maintain facial identity with expressive colors."
+     elif trigger_word == "Cezanne":
+         prompt = f"Render the person in Paul Cézanne's geometric post-impressionist style, {prompt}. Keep facial structure while applying structured brushwork."
+     elif trigger_word == "Renoir":
+         prompt = f"Paint the portrait in Pierre-Auguste Renoir's impressionist style with soft light, {prompt}. Maintain identity with luminous skin tones."
+     elif trigger_word == "claude monet":
+         prompt = f"Create an impressionist portrait in Claude Monet's style with visible brushstrokes, {prompt}. Keep facial features while using light and color."
+     elif trigger_word == "fantasy":
+         prompt = f"Transform into an epic fantasy character portrait, {prompt}. Maintain facial identity while adding magical and fantastical elements."
+     elif trigger_word == ", How2Draw":
+         prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
+     elif trigger_word == ", video game screenshot in the style of THSMS":
+         prompt = f"create a video game screenshot in the style of THSMS with the person from the photo, {prompt}. maintain the facial identity of the person and general features"
+     else:
+         prompt = f"convert the style of this portrait photo to {trigger_word} while maintaining the identity of the person. {prompt}. Make sure to maintain the person's facial identity and features, while still changing the overall style to {trigger_word}."
+
+     try:
+         image = pipe(
+             image=input_image,
+             prompt=prompt,
+             guidance_scale=guidance_scale,
+             generator=torch.Generator().manual_seed(seed),
+         ).images[0]
+
+         return image, seed, gr.update(visible=True)
+
+     except Exception as e:
+         print(f"Error during inference: {e}")
+         return None, seed, gr.update(visible=False)
+
+ # Create Gradio interface
+ with gr.Blocks(css=css) as demo:
+     gr_flux_loras = gr.State(value=flux_loras_raw)
+
+     title = gr.HTML(
+         """<h1>FLUX Kontex Super LoRAs🖖</h1>""",
      )
+
+     selected_state = gr.State(value=None)
+     custom_loaded_lora = gr.State(value=None)
+
+     with gr.Row(elem_id="main_app"):
+         with gr.Column(scale=4, elem_id="box_column"):
+             with gr.Group(elem_id="gallery_box"):
+                 input_image = gr.Image(label="Upload your portrait photo 📸", type="pil", height=300)
+
+                 gallery = gr.Gallery(
+                     label="Choose Your Art Style",
+                     allow_preview=False,
+                     columns=3,
+                     elem_id="gallery",
+                     show_share_button=False,
+                     height=400
+                 )
+
+                 custom_model = gr.Textbox(
+                     label="🔗 Or use a custom LoRA from HuggingFace",
+                     placeholder="e.g., username/lora-name",
+                     visible=True
+                 )
+                 custom_model_card = gr.HTML(visible=False)
+                 custom_model_button = gr.Button("❌ Remove custom LoRA", visible=False)
+
+         with gr.Column(scale=5):
+             with gr.Row():
+                 prompt = gr.Textbox(
+                     label="Additional Details (optional)",
+                     show_label=False,
+                     lines=1,
+                     max_lines=1,
+                     placeholder="Describe additional details, e.g., 'wearing a red hat' or 'smiling'",
+                     elem_id="prompt"
+                 )
+                 run_button = gr.Button("Generate ✨", elem_id="run_button")
+
+             result = gr.Image(label="Your Artistic Portrait", interactive=False)
+             reuse_button = gr.Button("🔄 Reuse this image", visible=False)
+
+             with gr.Accordion("⚙️ Advanced Settings", open=False):
+                 lora_scale = gr.Slider(
+                     label="Style Strength",
+                     minimum=0,
+                     maximum=2,
+                     step=0.1,
+                     value=1.0,
+                     info="How strongly to apply the art style (1.0 = balanced)"
+                 )
+                 seed = gr.Slider(
+                     label="Random Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                     info="Set to 0 for random results"
+                 )
+                 randomize_seed = gr.Checkbox(label="🎲 Randomize seed for each generation", value=True)
+                 guidance_scale = gr.Slider(
+                     label="Image Guidance",
+                     minimum=1,
+                     maximum=10,
+                     step=0.1,
+                     value=2.5,
+                     info="How closely to follow the input image (lower = more creative)"
+                 )
+
+     prompt_title = gr.Markdown(
+         value="### 🎨 Select an art style from the gallery",
+         visible=True,
+         elem_id="selected_lora",
+     )
+
+     # Event handlers
+     custom_model.input(
+         fn=load_custom_lora,
+         inputs=[custom_model],
+         outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title, selected_state],
+     )
+
+     custom_model_button.click(
+         fn=remove_custom_lora,
+         outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora, selected_state]
+     )
+
+     gallery.select(
+         fn=update_selection,
+         inputs=[gr_flux_loras],
+         outputs=[prompt_title, prompt, selected_state],
+         show_progress=False
+     )
+
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=infer_with_lora_wrapper,
+         inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
+         outputs=[result, seed, reuse_button]
+     )
+
+     reuse_button.click(
+         fn=lambda image: image,
+         inputs=[result],
+         outputs=[input_image]
+     )
+
+     # Initialize gallery
+     demo.load(
+         fn=classify_gallery,
+         inputs=[gr_flux_loras],
+         outputs=[gallery, gr_flux_loras]
+     )
+
+ demo.queue(default_concurrency_limit=None)
+ demo.launch()
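
The core pattern the new app.py adds is: download a LoRA adapter from the Hub, attach it to the FLUX.1 Kontext pipeline, and run an image-to-image edit whose prompt carries the adapter's trigger word. The following is a minimal standalone sketch of that flow, assembled only from calls that appear in the diff above; it assumes a CUDA device, the openfree/flux-chatgpt-ghibli-lora adapter from the gallery list, and any local portrait image (the example path "examples/1.png" is just a placeholder).

    import torch
    from diffusers import FluxKontextPipeline
    from diffusers.utils import load_image
    from huggingface_hub import hf_hub_download

    # Base pipeline, same checkpoint and dtype as in app.py (CUDA GPU assumed).
    pipe = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Fetch one of the gallery adapters and attach it at full strength.
    lora_path = hf_hub_download(
        repo_id="openfree/flux-chatgpt-ghibli-lora",
        filename="pytorch_lora_weights.safetensors",
    )
    pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
    pipe.set_adapters(["selected_lora"], adapter_weights=[1.0])

    # Image-to-image edit: the prompt embeds the adapter's trigger word ("ghibli"),
    # mirroring the prompt template used in infer_with_lora.
    portrait = load_image("examples/1.png").convert("RGB")  # placeholder portrait photo
    result = pipe(
        image=portrait,
        prompt=(
            "Create a Studio Ghibli anime style portrait of the person in the photo, smiling. "
            "Maintain the facial identity while transforming into whimsical anime art style."
        ),
        guidance_scale=2.5,
        generator=torch.Generator().manual_seed(42),
    ).images[0]
    result.save("ghibli_portrait.png")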