ginipick commited on
Commit
78f8a2a
·
verified ·
1 Parent(s): 24c51c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -31
app.py CHANGED
@@ -181,40 +181,32 @@ def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps,
181
  torch.cuda.empty_cache()
182
  return None
183
 
184
- def create_thumbnail_html():
185
- """Create HTML for thumbnail grid"""
186
- html = '<div style="display: grid; grid-template-columns: repeat(6, 1fr); gap: 10px; max-width: 800px; margin: 0 auto;">'
187
-
 
 
 
188
  styles = list(style_type_lora_dict.keys())
189
 
190
- for i, style in enumerate(styles):
191
- if i >= 24: # Limit to 24 thumbnails for 6x4 grid
192
- break
193
-
194
  thumbnail_file = thumbnail_mapping.get(style, "")
195
- style_readable = style.replace('_', ' ')
196
-
197
- html += f'''
198
- <div style="text-align: center; cursor: pointer;" onclick="document.getElementById('style_dropdown').value='{style}';
199
- var event = new Event('change', {{bubbles: true}});
200
- document.getElementById('style_dropdown').dispatchEvent(event);">
201
- <img src="file/{thumbnail_file}" alt="{style_readable}"
202
- style="width: 100%; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
203
- transition: transform 0.2s, box-shadow 0.2s;"
204
- onmouseover="this.style.transform='scale(1.05)'; this.style.boxShadow='0 4px 8px rgba(0,0,0,0.2)';"
205
- onmouseout="this.style.transform='scale(1)'; this.style.boxShadow='0 2px 4px rgba(0,0,0,0.1)';">
206
- <p style="margin: 5px 0; font-size: 12px; font-weight: 500;">{style_readable}</p>
207
- </div>
208
- '''
209
-
210
- # Fill empty slots if needed
211
- remaining_slots = 24 - len(styles)
212
- if remaining_slots > 0 and len(styles) < 24:
213
- for _ in range(remaining_slots):
214
- html += '<div></div>'
215
 
216
- html += '</div>'
217
- return html
218
 
219
  # Create Gradio interface
220
  with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
@@ -228,8 +220,20 @@ with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as
228
 
229
  # Thumbnail Grid Section
230
  gr.Markdown("### 🖼️ Click a style thumbnail to select it:")
 
231
  with gr.Row():
232
- gr.HTML(create_thumbnail_html())
 
 
 
 
 
 
 
 
 
 
 
233
 
234
  gr.Markdown("---")
235
 
@@ -307,6 +311,22 @@ with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as
307
  - Use additional instructions for fine control
308
  """)
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  # Update style description when style changes
311
  def update_description(style):
312
  return style_descriptions.get(style, "")
 
181
  torch.cuda.empty_cache()
182
  return None
183
 
184
+ def select_style(style_name):
185
+ """Handler for thumbnail clicks"""
186
+ return style_name, style_descriptions.get(style_name, "")
187
+
188
+ def create_thumbnail_grid():
189
+ """Create a gallery of style thumbnails"""
190
+ thumbnails = []
191
  styles = list(style_type_lora_dict.keys())
192
 
193
+ for style in styles:
 
 
 
194
  thumbnail_file = thumbnail_mapping.get(style, "")
195
+ if thumbnail_file and os.path.exists(thumbnail_file):
196
+ try:
197
+ img = Image.open(thumbnail_file)
198
+ thumbnails.append((img, style.replace('_', ' ')))
199
+ except Exception as e:
200
+ print(f"Error loading thumbnail {thumbnail_file}: {e}")
201
+ # Create placeholder if thumbnail fails to load
202
+ placeholder = Image.new('RGB', (256, 256), color='lightgray')
203
+ thumbnails.append((placeholder, style.replace('_', ' ')))
204
+ else:
205
+ # Create placeholder for missing thumbnails
206
+ placeholder = Image.new('RGB', (256, 256), color='lightgray')
207
+ thumbnails.append((placeholder, style.replace('_', ' ')))
 
 
 
 
 
 
 
208
 
209
+ return thumbnails
 
210
 
211
  # Create Gradio interface
212
  with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
 
220
 
221
  # Thumbnail Grid Section
222
  gr.Markdown("### 🖼️ Click a style thumbnail to select it:")
223
+
224
  with gr.Row():
225
+ style_gallery = gr.Gallery(
226
+ value=create_thumbnail_grid(),
227
+ label="Style Thumbnails",
228
+ show_label=False,
229
+ elem_id="style_gallery",
230
+ columns=6,
231
+ rows=4,
232
+ object_fit="cover",
233
+ height="auto",
234
+ interactive=True,
235
+ show_download_button=False
236
+ )
237
 
238
  gr.Markdown("---")
239
 
 
311
  - Use additional instructions for fine control
312
  """)
313
 
314
+ # Handle gallery selection
315
+ def on_gallery_select(evt: gr.SelectData):
316
+ """Handle thumbnail selection from gallery"""
317
+ selected_index = evt.index
318
+ styles = list(style_type_lora_dict.keys())
319
+ if 0 <= selected_index < len(styles):
320
+ selected_style = styles[selected_index]
321
+ return selected_style, style_descriptions.get(selected_style, "")
322
+ return None, None
323
+
324
+ style_gallery.select(
325
+ fn=on_gallery_select,
326
+ inputs=None,
327
+ outputs=[style_dropdown, style_info]
328
+ )
329
+
330
  # Update style description when style changes
331
  def update_description(style):
332
  return style_descriptions.get(style, "")