lionelgarnier commited on
Commit
5c85703
·
1 Parent(s): 962f2bc

cahnge image type to PIL

Browse files
Files changed (1) hide show
  1. app.py +10 -23
app.py CHANGED
@@ -62,7 +62,7 @@ def end_session(req: gr.Request):
62
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
63
  shutil.rmtree(user_dir)
64
 
65
- def preprocess_image(image):
66
  trellis = get_trellis_pipeline()
67
  if trellis is None:
68
  # If the pipeline is not loaded, just return the original image
@@ -313,7 +313,7 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
313
 
314
  @spaces.GPU
315
  def image_to_3d(
316
- image,
317
  seed: int,
318
  ss_guidance_strength: float,
319
  ss_sampling_steps: int,
@@ -321,31 +321,18 @@ def image_to_3d(
321
  slat_sampling_steps: int,
322
  ) -> Tuple[dict, str]:
323
  try:
324
- if isinstance(image, dict) and "image" in image:
325
- image = image["image"]
326
-
327
- # If user passed multiple images
328
- if isinstance(image, list):
329
- input_image = []
330
- for img in image:
331
- if isinstance(img, dict) and "image" in img:
332
- img = img["image"]
333
- if isinstance(img, np.ndarray):
334
- img = Image.fromarray(img.astype("uint8"))
335
- input_image.append(img)
336
- else:
337
- # Single image
338
- if isinstance(image, np.ndarray):
339
- image = Image.fromarray(image.astype("uint8"))
340
- input_image = [image]
341
-
342
  pipeline = get_trellis_pipeline()
343
  if pipeline is None:
344
  return None, "Trellis pipeline is unavailable."
345
  pipeline.cuda()
346
 
 
 
 
 
347
  outputs = pipeline.run(
348
- input_image,
349
  seed=seed,
350
  formats=["gaussian", "mesh"],
351
  preprocess_image=False,
@@ -473,7 +460,7 @@ def create_interface():
473
  max_length=2048,
474
  )
475
  visual_button = gr.Button("Create visual with Flux")
476
- generated_image = gr.Image(show_label=False)
477
 
478
  preprocessed_button = gr.Button("Preprocess image")
479
  preprocessed_image = gr.Image(show_label=False)
@@ -586,7 +573,7 @@ def create_interface():
586
  gr.on(
587
  triggers=[gen3d_button.click],
588
  fn=image_to_3d,
589
- inputs=[preprocessed_image, trellis_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
590
  outputs=[output_state, video_output],
591
  )
592
  # .then(
 
62
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
63
  shutil.rmtree(user_dir)
64
 
65
+ def preprocess_image(image: Image.Image) -> Image.Image:
66
  trellis = get_trellis_pipeline()
67
  if trellis is None:
68
  # If the pipeline is not loaded, just return the original image
 
313
 
314
  @spaces.GPU
315
  def image_to_3d(
316
+ image: Image.Image,
317
  seed: int,
318
  ss_guidance_strength: float,
319
  ss_sampling_steps: int,
 
321
  slat_sampling_steps: int,
322
  ) -> Tuple[dict, str]:
323
  try:
324
+ # Load the Trellis pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  pipeline = get_trellis_pipeline()
326
  if pipeline is None:
327
  return None, "Trellis pipeline is unavailable."
328
  pipeline.cuda()
329
 
330
+ # Preprocess image
331
+ image = preprocess_image(image)
332
+
333
+ # Run the pipeline
334
  outputs = pipeline.run(
335
+ image,
336
  seed=seed,
337
  formats=["gaussian", "mesh"],
338
  preprocess_image=False,
 
460
  max_length=2048,
461
  )
462
  visual_button = gr.Button("Create visual with Flux")
463
+ generated_image = gr.Image(show_label=False, format="png", image_mode="RGBA", type="pil", height=300)
464
 
465
  preprocessed_button = gr.Button("Preprocess image")
466
  preprocessed_image = gr.Image(show_label=False)
 
573
  gr.on(
574
  triggers=[gen3d_button.click],
575
  fn=image_to_3d,
576
+ inputs=[generated_image, trellis_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
577
  outputs=[output_state, video_output],
578
  )
579
  # .then(