chengzeyi committed on
Commit
14e2b9a
·
verified ·
1 Parent(s): 6e4d5c2

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +53 -17
app.py CHANGED
@@ -9,6 +9,7 @@ from dotenv import load_dotenv
9
  import gradio as gr
10
  import random
11
  import torch
 
12
  from PIL import Image, ImageDraw, ImageFont
13
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
14
 
@@ -106,6 +107,11 @@ def image_to_base64(file_path):
106
  return base64.b64encode(f.read()).decode()
107
 
108
 
 
 
 
 
 
109
  def generate_image(image_file,
110
  prompt,
111
  seed,
@@ -116,13 +122,13 @@ def generate_image(image_file,
116
  safety_level = classify_prompt(prompt)
117
  if safety_level != 0:
118
  error_img = create_error_image(CLASS_NAMES[safety_level])
119
- yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img, ""
120
  return
121
 
122
  if not rate_limiter.check(session_id):
123
  error_img = create_error_image(
124
  "Hourly limit exceeded (20 requests)")
125
- yield "❌ Too many requests, please try again later", error_img, ""
126
  return
127
 
128
  session = session_manager.get_session(session_id)
@@ -138,14 +144,15 @@ def generate_image(image_file,
138
  error_messages.append("Prompt cannot be empty")
139
  if error_messages:
140
  error_img = create_error_image(" | ".join(error_messages))
141
- yield "❌ Input validation failed", error_img, ""
142
  return
143
 
144
  try:
145
  base64_image = image_to_base64(image_file)
 
146
  except Exception as e:
147
  error_img = create_error_image(f"File processing failed: {str(e)}")
148
- yield "❌ File processing failed", error_img, ""
149
  return
150
 
151
  headers = {
@@ -180,20 +187,20 @@ def generate_image(image_file,
180
 
181
  if status == "completed":
182
  elapsed = time.time() - start_time
183
- image_url = data["outputs"][0]
184
- session["history"].append(image_url)
185
- yield f"πŸŽ‰ Generation successful! Time taken {elapsed:.1f}s", image_url, image_url
186
  return
187
  elif status == "failed":
188
  raise Exception(data.get("error", "Unknown error"))
189
  else:
190
- yield f"⏳ Current status: {status.capitalize()}...", None, None
191
 
192
  raise Exception("Generation timed out")
193
 
194
  except Exception as e:
195
  error_img = create_error_image(str(e))
196
- yield f"❌ Generation failed: {str(e)}", error_img, ""
197
 
198
 
199
  def cleanup_task():
@@ -202,6 +209,9 @@ def cleanup_task():
202
  time.sleep(3600)
203
 
204
 
 
 
 
205
  with gr.Blocks(theme=gr.themes.Soft(),
206
  css="""
207
  .status-box { padding: 10px; border-radius: 5px; margin: 5px; }
@@ -212,27 +222,29 @@ with gr.Blocks(theme=gr.themes.Soft(),
212
 
213
  session_id = gr.State(str(uuid.uuid4()))
214
 
215
- gr.Markdown("# πŸ–ΌοΈFLUX Kontext Dev Ultra Fast Live On Wavespeed AI")
216
  gr.Markdown(
217
  "FLUX Kontext Dev is a new SOTA image editing model published by Black Forest Labs. We have deployed it on [WaveSpeedAI](https://wavespeed.ai/) for ultra-fast image editing. You can use it to edit images in various styles, add objects, or even change the mood of the image. It supports both text prompts and image inputs."
218
  )
219
  gr.Markdown(
220
- "[FLUX Kontext Dev on WaveSpeedAI](https://wavespeed.ai/models/wavespeed-ai/flux-kontext-dev)"
 
 
 
221
  )
222
  gr.Markdown(
223
- "[FLUX Kontext Dev Ultra Fast on WaveSpeedAI](https://wavespeed.ai/models/wavespeed-ai/flux-kontext-dev-ultra-fast)"
224
  )
225
 
226
  with gr.Row():
227
  with gr.Column(scale=1):
 
 
 
228
  image_file = gr.Image(label="Upload Image",
229
  type="filepath",
230
  sources=["upload"],
231
  interactive=True,
232
  image_mode="RGB")
233
- prompt = gr.Textbox(label="Prompt",
234
- placeholder="Please enter your prompt...",
235
- lines=3)
236
  seed = gr.Number(label="seed",
237
  value=-1,
238
  minimum=-1,
@@ -243,11 +255,11 @@ with gr.Blocks(theme=gr.themes.Soft(),
243
  value=True,
244
  interactive=False)
245
  with gr.Column(scale=1):
 
246
  output_image = gr.Image(label="Generated Result")
247
  output_url = gr.Textbox(label="Image URL",
248
  interactive=True,
249
  visible=False)
250
- status = gr.Textbox(label="Status", elem_classes=["status-box"])
251
  submit_btn = gr.Button("Start Generation", variant="primary")
252
  gr.Examples(
253
  examples=[
@@ -275,12 +287,36 @@ with gr.Blocks(theme=gr.themes.Soft(),
275
  inputs=[prompt, image_file],
276
  label="Examples")
277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  random_btn.click(fn=lambda: random.randint(0, 999999), outputs=seed)
279
 
280
  submit_btn.click(
281
  generate_image,
282
  inputs=[image_file, prompt, seed, session_id, enable_safety],
283
- outputs=[status, output_image, output_url],
284
  api_name=False,
285
  )
286
 
 
9
  import gradio as gr
10
  import random
11
  import torch
12
+ import io
13
  from PIL import Image, ImageDraw, ImageFont
14
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
15
 
 
107
  return base64.b64encode(f.read()).decode()
108
 
109
 
110
def decode_base64_to_image(base64_str):
    """Decode a base64-encoded image payload into a PIL Image object."""
    raw_bytes = base64.b64decode(base64_str)
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)
113
+
114
+
115
  def generate_image(image_file,
116
  prompt,
117
  seed,
 
122
  safety_level = classify_prompt(prompt)
123
  if safety_level != 0:
124
  error_img = create_error_image(CLASS_NAMES[safety_level])
125
+ yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img, "", None
126
  return
127
 
128
  if not rate_limiter.check(session_id):
129
  error_img = create_error_image(
130
  "Hourly limit exceeded (20 requests)")
131
+ yield "❌ Too many requests, please try again later", error_img, "", None
132
  return
133
 
134
  session = session_manager.get_session(session_id)
 
144
  error_messages.append("Prompt cannot be empty")
145
  if error_messages:
146
  error_img = create_error_image(" | ".join(error_messages))
147
+ yield "❌ Input validation failed", error_img, "", None
148
  return
149
 
150
  try:
151
  base64_image = image_to_base64(image_file)
152
+ input_image = decode_base64_to_image(base64_image)
153
  except Exception as e:
154
  error_img = create_error_image(f"File processing failed: {str(e)}")
155
+ yield "❌ File processing failed", error_img, "", None
156
  return
157
 
158
  headers = {
 
187
 
188
  if status == "completed":
189
  elapsed = time.time() - start_time
190
+ output_url = data["outputs"][0]
191
+ session["history"].append(output_url)
192
+ yield f"πŸŽ‰ Generation successful! Time taken {elapsed:.1f}s", output_url, output_url, update_recent_gallery(prompt, input_image, output_url)
193
  return
194
  elif status == "failed":
195
  raise Exception(data.get("error", "Unknown error"))
196
  else:
197
+ yield f"⏳ Current status: {status.capitalize()}...", None, None, None
198
 
199
  raise Exception("Generation timed out")
200
 
201
  except Exception as e:
202
  error_img = create_error_image(str(e))
203
+ yield f"❌ Generation failed: {str(e)}", error_img, "", None
204
 
205
 
206
  def cleanup_task():
 
209
  time.sleep(3600)
210
 
211
 
212
+ # Store recent generations
213
+ recent_generations = []
214
+
215
  with gr.Blocks(theme=gr.themes.Soft(),
216
  css="""
217
  .status-box { padding: 10px; border-radius: 5px; margin: 5px; }
 
222
 
223
  session_id = gr.State(str(uuid.uuid4()))
224
 
225
+ gr.Markdown("# πŸ–ΌοΈFLUX Kontext Dev Ultra Fast Live")
226
  gr.Markdown(
227
  "FLUX Kontext Dev is a new SOTA image editing model published by Black Forest Labs. We have deployed it on [WaveSpeedAI](https://wavespeed.ai/) for ultra-fast image editing. You can use it to edit images in various styles, add objects, or even change the mood of the image. It supports both text prompts and image inputs."
228
  )
229
  gr.Markdown(
230
+ "- [FLUX Kontext Dev on WaveSpeedAI](https://wavespeed.ai/models/wavespeed-ai/flux-kontext-dev)"
231
+ "- [FLUX Kontext Dev LoRA on WaveSpeedAI](https://wavespeed.ai/models/wavespeed-ai/flux-kontext-dev-lora)"
232
+ "- [FLUX Kontext Dev Ultra Fast on WaveSpeedAI](https://wavespeed.ai/models/wavespeed-ai/flux-kontext-dev-ultra-fast)"
233
+ "- [FLUX Kontext Dev LoRA Ultra Fast on WaveSpeedAI](https://wavespeed.ai/models/wavespeed-ai/flux-kontext-dev-lora-ultra-fast)"
234
  )
235
  gr.Markdown(
 
236
  )
237
 
238
  with gr.Row():
239
  with gr.Column(scale=1):
240
+ prompt = gr.Textbox(label="Prompt",
241
+ placeholder="Please enter your prompt...",
242
+ lines=3)
243
  image_file = gr.Image(label="Upload Image",
244
  type="filepath",
245
  sources=["upload"],
246
  interactive=True,
247
  image_mode="RGB")
 
 
 
248
  seed = gr.Number(label="seed",
249
  value=-1,
250
  minimum=-1,
 
255
  value=True,
256
  interactive=False)
257
  with gr.Column(scale=1):
258
+ status = gr.Textbox(label="Status", elem_classes=["status-box"])
259
  output_image = gr.Image(label="Generated Result")
260
  output_url = gr.Textbox(label="Image URL",
261
  interactive=True,
262
  visible=False)
 
263
  submit_btn = gr.Button("Start Generation", variant="primary")
264
  gr.Examples(
265
  examples=[
 
287
  inputs=[prompt, image_file],
288
  label="Examples")
289
 
290
+ with gr.Accordion("Recent Generations (last 16)", open=False):
291
+ recent_gallery = gr.Gallery(label="Prompt and Output",
292
+ columns=3,
293
+ interactive=False)
294
+
295
def get_recent_gallery_items():
    """Return a gr.update carrying (image, caption) pairs for the gallery.

    Newest generations come first; records with any missing field are
    skipped so the gallery never shows half-finished entries.
    """
    items = []
    for record in reversed(recent_generations):
        # Skip entries where the prompt, input, or output is missing.
        if any(value is None for value in record.values()):
            continue
        items.append((record["input"], f"Input: {record['prompt']}"))
        items.append((record["output"], f"Output: {record['prompt']}"))
    return gr.update(value=items)
303
+
304
def update_recent_gallery(prompt, input_image, output_image, max_items=16):
    """Record one generation and return the refreshed gallery update.

    Args:
        prompt: Text prompt used for this generation.
        input_image: Source image for the generation (decoded upload).
        output_image: URL of the generated result.
        max_items: Maximum number of generations to retain; oldest entries
            are dropped first. Defaults to 16 to match the accordion label
            "Recent Generations (last 16)".

    Returns:
        The gr.update value produced by get_recent_gallery_items().
    """
    recent_generations.append({
        "prompt": prompt,
        "input": input_image,
        "output": output_image,
    })
    # Trim from the front so only the newest max_items entries survive;
    # `while` (not `if`) keeps the invariant even if the cap shrinks.
    while len(recent_generations) > max_items:
        recent_generations.pop(0)
    return get_recent_gallery_items()
313
+
314
  random_btn.click(fn=lambda: random.randint(0, 999999), outputs=seed)
315
 
316
  submit_btn.click(
317
  generate_image,
318
  inputs=[image_file, prompt, seed, session_id, enable_safety],
319
+ outputs=[status, output_image, output_url, recent_gallery],
320
  api_name=False,
321
  )
322