linoyts HF Staff committed on
Commit
5755f6a
·
verified ·
1 Parent(s): 47410cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -7
app.py CHANGED
@@ -180,10 +180,71 @@ except Exception as e:
180
  # --- UI Constants and Helpers ---
181
  MAX_SEED = np.iinfo(np.int32).max
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  # --- Main Inference Function ---
184
  @spaces.GPU(duration=60)
185
  def infer(
186
- image,
187
  prompt,
188
  seed=42,
189
  randomize_seed=False,
@@ -203,6 +264,20 @@ def infer(
203
 
204
  # Set up the generator for reproducibility
205
  generator = torch.Generator(device=device).manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  print(f"Original prompt: '{prompt}'")
208
  print(f"Negative Prompt: '{negative_prompt}'")
@@ -215,7 +290,7 @@ def infer(
215
  # Generate the edited image - always generate just 1 image
216
  try:
217
  images = pipe(
218
- image,
219
  prompt=prompt,
220
  negative_prompt=negative_prompt,
221
  num_inference_steps=num_inference_steps,
@@ -268,10 +343,16 @@ with gr.Blocks(css=css) as demo:
268
 
269
  with gr.Row():
270
  with gr.Column():
271
- input_image = gr.Image(
272
- label="Input Image",
273
- show_label=True,
274
- type="pil"
 
 
 
 
 
 
275
  )
276
  # Changed from Gallery to Image
277
  result = gr.Image(
@@ -329,7 +410,7 @@ with gr.Blocks(css=css) as demo:
329
  triggers=[run_button.click, prompt.submit],
330
  fn=infer,
331
  inputs=[
332
- input_image,
333
  prompt,
334
  seed,
335
  randomize_seed,
 
180
  # --- UI Constants and Helpers ---
181
  MAX_SEED = np.iinfo(np.int32).max
182
 
183
def concatenate_images(images, direction="horizontal"):
    """Join multiple PIL images into one canvas, side by side or stacked.

    Args:
        images: List of PIL Images (entries may be None; the list itself
            may be None or empty).
        direction: "horizontal" or "vertical".

    Returns:
        PIL Image: the combined RGB image, or None when no usable image
        was supplied.
    """
    # Drop missing entries up front; an empty or all-None input yields None.
    usable = [im for im in (images or []) if im is not None]
    if not usable:
        return None

    # Normalise everything to RGB so paste() never mixes colour modes.
    usable = [im.convert("RGB") for im in usable]

    # Single image: nothing to concatenate.
    if len(usable) == 1:
        return usable[0]

    horizontal = direction == "horizontal"
    if horizontal:
        canvas_width = sum(im.width for im in usable)
        canvas_height = max(im.height for im in usable)
    else:
        canvas_width = max(im.width for im in usable)
        canvas_height = sum(im.height for im in usable)

    # White background fills any gaps left by differently-sized images.
    canvas = Image.new('RGB', (canvas_width, canvas_height), (255, 255, 255))

    offset = 0
    for im in usable:
        if horizontal:
            # Centre each image vertically when heights differ.
            canvas.paste(im, (offset, (canvas_height - im.height) // 2))
            offset += im.width
        else:
            # Centre each image horizontally when widths differ.
            canvas.paste(im, ((canvas_width - im.width) // 2, offset))
            offset += im.height

    return canvas
244
  # --- Main Inference Function ---
245
  @spaces.GPU(duration=60)
246
  def infer(
247
+ input_images,
248
  prompt,
249
  seed=42,
250
  randomize_seed=False,
 
264
 
265
  # Set up the generator for reproducibility
266
  generator = torch.Generator(device=device).manual_seed(seed)
267
+
268
+ # Handle input_images - it could be a single image or a list of images
269
+ if input_images is None:
270
+ raise gr.Error("Please upload at least one image.")
271
+
272
+ # If it's a single image (not a list), convert to list
273
+ if not isinstance(input_images, list):
274
+ input_images = [input_images]
275
+
276
+ # Concatenate images horizontally
277
+ concatenated_image = concatenate_images(input_images, "horizontal")
278
+
279
+ if concatenated_image is None:
280
+ raise gr.Error("Failed to process the input images.")
281
 
282
  print(f"Original prompt: '{prompt}'")
283
  print(f"Negative Prompt: '{negative_prompt}'")
 
290
  # Generate the edited image - always generate just 1 image
291
  try:
292
  images = pipe(
293
+ concatenated_image,
294
  prompt=prompt,
295
  negative_prompt=negative_prompt,
296
  num_inference_steps=num_inference_steps,
 
343
 
344
  with gr.Row():
345
  with gr.Column():
346
+ input_images = gr.Gallery(
347
+ label="Upload image(s) for editing",
348
+ show_label=True,
349
+ elem_id="gallery_input",
350
+ columns=3,
351
+ rows=2,
352
+ object_fit="contain",
353
+ height="auto",
354
+ file_types=['image'],
355
+ type='pil'
356
  )
357
  # Changed from Gallery to Image
358
  result = gr.Image(
 
410
  triggers=[run_button.click, prompt.submit],
411
  fn=infer,
412
  inputs=[
413
+ input_images,
414
  prompt,
415
  seed,
416
  randomize_seed,