lionelgarnier committed on
Commit
962f2bc
·
1 Parent(s): 5e2ea0f

update image_to_3d to support multiple images and handle numpy arrays

Browse files
Files changed (1) hide show
  1. app.py +20 -16
app.py CHANGED
@@ -86,7 +86,8 @@ def get_image_gen_pipeline():
86
  device = "cuda" if torch.cuda.is_available() else "cpu"
87
  dtype = torch.bfloat16
88
  _image_gen_pipeline = DiffusionPipeline.from_pretrained(
89
- "black-forest-labs/FLUX.1-schnell",
 
90
  torch_dtype=dtype,
91
  ).to(device)
92
 
@@ -320,25 +321,29 @@ def image_to_3d(
320
  slat_sampling_steps: int,
321
  ) -> Tuple[dict, str]:
322
  try:
323
- # Use a fixed temp directory instead of user-specific
324
- temp_dir = os.path.join(TMP_DIR, "temp_output")
325
- os.makedirs(temp_dir, exist_ok=True)
326
 
327
- # Get the pipeline using the getter function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  pipeline = get_trellis_pipeline()
329
  if pipeline is None:
330
  return None, "Trellis pipeline is unavailable."
331
-
332
- # Call cuda() here in the GPU worker process
333
  pipeline.cuda()
334
-
335
- # Convert image to the right format if needed
336
- if isinstance(image, np.ndarray):
337
- image = Image.fromarray(image.astype('uint8'))
338
-
339
- # Make sure we have a list of images as expected by the pipeline
340
- input_image = [image]
341
-
342
  outputs = pipeline.run(
343
  input_image,
344
  seed=seed,
@@ -627,4 +632,3 @@ if __name__ == "__main__":
627
 
628
  demo = create_interface()
629
  demo.launch(debug=True)
630
-
 
86
  device = "cuda" if torch.cuda.is_available() else "cpu"
87
  dtype = torch.bfloat16
88
  _image_gen_pipeline = DiffusionPipeline.from_pretrained(
89
+ # "black-forest-labs/FLUX.1-schnell",
90
+ "black-forest-labs/FLUX.1-dev",
91
  torch_dtype=dtype,
92
  ).to(device)
93
 
 
321
  slat_sampling_steps: int,
322
  ) -> Tuple[dict, str]:
323
  try:
324
+ if isinstance(image, dict) and "image" in image:
325
+ image = image["image"]
 
326
 
327
+ # If user passed multiple images
328
+ if isinstance(image, list):
329
+ input_image = []
330
+ for img in image:
331
+ if isinstance(img, dict) and "image" in img:
332
+ img = img["image"]
333
+ if isinstance(img, np.ndarray):
334
+ img = Image.fromarray(img.astype("uint8"))
335
+ input_image.append(img)
336
+ else:
337
+ # Single image
338
+ if isinstance(image, np.ndarray):
339
+ image = Image.fromarray(image.astype("uint8"))
340
+ input_image = [image]
341
+
342
  pipeline = get_trellis_pipeline()
343
  if pipeline is None:
344
  return None, "Trellis pipeline is unavailable."
 
 
345
  pipeline.cuda()
346
+
 
 
 
 
 
 
 
347
  outputs = pipeline.run(
348
  input_image,
349
  seed=seed,
 
632
 
633
  demo = create_interface()
634
  demo.launch(debug=True)