Spaces:
Build error
Build error
Fixed tensor dimension issues for multi-view processing -- again
Browse files
hy3dshape/hy3dshape/pipelines.py
CHANGED
|
@@ -500,10 +500,41 @@ class Hunyuan3DDiTPipeline:
|
|
| 500 |
|
| 501 |
# Handle dictionary input (multi-view mode)
|
| 502 |
if isinstance(image, dict):
|
| 503 |
-
#
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
if not isinstance(image, list):
|
| 509 |
image = [image]
|
|
|
|
| 500 |
|
| 501 |
# Handle dictionary input (multi-view mode)
|
| 502 |
if isinstance(image, dict):
|
| 503 |
+
# Process each view individually with the single-image processor
|
| 504 |
+
# and then combine them appropriately
|
| 505 |
+
processed_views = []
|
| 506 |
+
view_order = []
|
| 507 |
+
|
| 508 |
+
# Define the standard view order
|
| 509 |
+
view_mapping = {'front': 0, 'left': 1, 'back': 2, 'right': 3}
|
| 510 |
+
|
| 511 |
+
# Sort views by their standard order
|
| 512 |
+
sorted_views = sorted(image.items(), key=lambda x: view_mapping.get(x[0], 999))
|
| 513 |
+
|
| 514 |
+
for view_name, view_image in sorted_views:
|
| 515 |
+
# Process each view individually
|
| 516 |
+
view_output = self.image_processor(view_image)
|
| 517 |
+
processed_views.append(view_output)
|
| 518 |
+
view_order.append(view_mapping.get(view_name, 0))
|
| 519 |
+
|
| 520 |
+
# Combine all views into a single batch
|
| 521 |
+
# Each view_output['image'] is expected to have shape [1, 3, H, W]; concatenate the views along the batch dimension
|
| 522 |
+
combined_images = []
|
| 523 |
+
combined_masks = []
|
| 524 |
+
|
| 525 |
+
for view_output in processed_views:
|
| 526 |
+
combined_images.append(view_output['image'])
|
| 527 |
+
combined_masks.append(view_output['mask'])
|
| 528 |
+
|
| 529 |
+
# Concatenate along the batch dimension (dim=0); the image tensor becomes [num_views, 3, H, W]
|
| 530 |
+
final_image = torch.cat(combined_images, dim=0)
|
| 531 |
+
final_mask = torch.cat(combined_masks, dim=0)
|
| 532 |
+
|
| 533 |
+
return {
|
| 534 |
+
'image': final_image,
|
| 535 |
+
'mask': final_mask,
|
| 536 |
+
'view_idxs': view_order
|
| 537 |
+
}
|
| 538 |
|
| 539 |
if not isinstance(image, list):
|
| 540 |
image = [image]
|