Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -181,40 +181,32 @@ def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps,
|
|
181 |
torch.cuda.empty_cache()
|
182 |
return None
|
183 |
|
184 |
-
def
|
185 |
-
"""
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
188 |
styles = list(style_type_lora_dict.keys())
|
189 |
|
190 |
-
for
|
191 |
-
if i >= 24: # Limit to 24 thumbnails for 6x4 grid
|
192 |
-
break
|
193 |
-
|
194 |
thumbnail_file = thumbnail_mapping.get(style, "")
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
'''
|
209 |
-
|
210 |
-
# Fill empty slots if needed
|
211 |
-
remaining_slots = 24 - len(styles)
|
212 |
-
if remaining_slots > 0 and len(styles) < 24:
|
213 |
-
for _ in range(remaining_slots):
|
214 |
-
html += '<div></div>'
|
215 |
|
216 |
-
|
217 |
-
return html
|
218 |
|
219 |
# Create Gradio interface
|
220 |
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
|
@@ -228,8 +220,20 @@ with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as
|
|
228 |
|
229 |
# Thumbnail Grid Section
|
230 |
gr.Markdown("### 🖼️ Click a style thumbnail to select it:")
|
|
|
231 |
with gr.Row():
|
232 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
gr.Markdown("---")
|
235 |
|
@@ -307,6 +311,22 @@ with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as
|
|
307 |
- Use additional instructions for fine control
|
308 |
""")
|
309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
# Update style description when style changes
|
311 |
def update_description(style):
|
312 |
return style_descriptions.get(style, "")
|
|
|
181 |
torch.cuda.empty_cache()
|
182 |
return None
|
183 |
|
184 |
+
def select_style(style_name):
|
185 |
+
"""Handler for thumbnail clicks"""
|
186 |
+
return style_name, style_descriptions.get(style_name, "")
|
187 |
+
|
188 |
+
def create_thumbnail_grid():
|
189 |
+
"""Create a gallery of style thumbnails"""
|
190 |
+
thumbnails = []
|
191 |
styles = list(style_type_lora_dict.keys())
|
192 |
|
193 |
+
for style in styles:
|
|
|
|
|
|
|
194 |
thumbnail_file = thumbnail_mapping.get(style, "")
|
195 |
+
if thumbnail_file and os.path.exists(thumbnail_file):
|
196 |
+
try:
|
197 |
+
img = Image.open(thumbnail_file)
|
198 |
+
thumbnails.append((img, style.replace('_', ' ')))
|
199 |
+
except Exception as e:
|
200 |
+
print(f"Error loading thumbnail {thumbnail_file}: {e}")
|
201 |
+
# Create placeholder if thumbnail fails to load
|
202 |
+
placeholder = Image.new('RGB', (256, 256), color='lightgray')
|
203 |
+
thumbnails.append((placeholder, style.replace('_', ' ')))
|
204 |
+
else:
|
205 |
+
# Create placeholder for missing thumbnails
|
206 |
+
placeholder = Image.new('RGB', (256, 256), color='lightgray')
|
207 |
+
thumbnails.append((placeholder, style.replace('_', ' ')))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
+
return thumbnails
|
|
|
210 |
|
211 |
# Create Gradio interface
|
212 |
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
|
|
|
220 |
|
221 |
# Thumbnail Grid Section
|
222 |
gr.Markdown("### 🖼️ Click a style thumbnail to select it:")
|
223 |
+
|
224 |
with gr.Row():
|
225 |
+
style_gallery = gr.Gallery(
|
226 |
+
value=create_thumbnail_grid(),
|
227 |
+
label="Style Thumbnails",
|
228 |
+
show_label=False,
|
229 |
+
elem_id="style_gallery",
|
230 |
+
columns=6,
|
231 |
+
rows=4,
|
232 |
+
object_fit="cover",
|
233 |
+
height="auto",
|
234 |
+
interactive=True,
|
235 |
+
show_download_button=False
|
236 |
+
)
|
237 |
|
238 |
gr.Markdown("---")
|
239 |
|
|
|
311 |
- Use additional instructions for fine control
|
312 |
""")
|
313 |
|
314 |
+
# Handle gallery selection
|
315 |
+
def on_gallery_select(evt: gr.SelectData):
|
316 |
+
"""Handle thumbnail selection from gallery"""
|
317 |
+
selected_index = evt.index
|
318 |
+
styles = list(style_type_lora_dict.keys())
|
319 |
+
if 0 <= selected_index < len(styles):
|
320 |
+
selected_style = styles[selected_index]
|
321 |
+
return selected_style, style_descriptions.get(selected_style, "")
|
322 |
+
return None, None
|
323 |
+
|
324 |
+
style_gallery.select(
|
325 |
+
fn=on_gallery_select,
|
326 |
+
inputs=None,
|
327 |
+
outputs=[style_dropdown, style_info]
|
328 |
+
)
|
329 |
+
|
330 |
# Update style description when style changes
|
331 |
def update_description(style):
|
332 |
return style_descriptions.get(style, "")
|