diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..42f2bf2587ce605108336e110346ec76abb25b34 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,46 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/images/np10_b07dbghsn7.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np10_cc486e491a2c499f9fd2aad2b02c6ccb.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np11_0603aac4bbf241c383115f9da3e38d80.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np11_fa2a78dddbbf44589ebe1a315c5077e7.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np12_6f9431c68ffa44d0b2e7635d646c4000.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np12_8a68419f5d454a28a7bc36da4395be60.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np13_39c0fa16ed324b54a605dcdbcd80797c.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np13_b07b4d858z.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np13_d8258903a23e4a368ede8fd584919b68.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np14_2faf0a30b1614a9ba029def4e63270d0.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np14_b8f46cf7daca419a87ac8d131bad056f.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np15_56f1ef2b84514e5e8c11a42aaab6cda7.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np2_2b7c6f4109fa4f84919bd30c367c663b.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np2_3cd2cdea87f34bd7acb096f5178b3760.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np3_2f6ab901c5a84ed6bbdf85a67b22a2ee.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np3_5bbc755b07924fcba6ccd702191f90df.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np3_7ba4129133dc46d0818930adb53c0315.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np3_7c00a1667b2c4413b4dcb24500f05892.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np3_be2e5af368e748b88871ed5776e951bb.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np3_c25308e36b3a4c6d9745e01fb34a93a1.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np3_d2e1aa2d2b4b424282d3b643fc8169cf.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np4_2444ea17f3a448b1bb7e2a74b276f015.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np4_4d6e4d3dd9194633ac86285ada1017ad.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np4_7bd5d25aa77b4fb18e780d7a4c97d342.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np4_80e586911397457aa9245eed1eb03abe.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np4_ca27601df1384a7aa152baacfb072306.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np5_23ae06bb5cf84e13ae973721fa5f5625.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np5_b81f29e567ea4db48014f89c9079e403.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np6_b7295d1f9c484a84a53f7ba62ead149e.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np6_bd2716459c774dc48fc793f1b76511e8.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np7_1749ad163235411295ed3342d024f1ac.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np7_1c004909dedb4ebe8db69b4d7b077434.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np7_8d3e568f059244e19f3b5f7e789cccb2.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np8_a7d7c17eadc54aa5a0094415025463ff.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np8_d84dc4fc47614a4687a774152b343ddd.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np9_5cc2b8234ff04a2aa29d5f0a0393ed0c.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np9_8031fc7690e640038ff6a6766e97f19d.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np9_94b4b6d936214b3f988bd4b520da4f53.png filter=lfs diff=lfs merge=lfs -text
+assets/images/np9_dd0ec139989c430fb6572fa024ae1c20.png filter=lfs diff=lfs merge=lfs -text
+assets/objects/scissors.glb filter=lfs diff=lfs merge=lfs -text
+assets/objects/sword.glb filter=lfs diff=lfs merge=lfs -text
+assets/robot.gif filter=lfs diff=lfs merge=lfs -text
+assets/teaser.png filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..37cf242ed4b5868e90e6b46e703d11f30aeaf759
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Yuchen Lin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/assets/images/np10_b07dbghsn7.png b/assets/images/np10_b07dbghsn7.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce65f571a2c94dda903526ff3275bd507f76ea5e
--- /dev/null
+++ b/assets/images/np10_b07dbghsn7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f77d44e70e4c3071bbe28bee74a71d7726fc477889e9a44339cdac38a926d86
+size 243201
diff --git a/assets/images/np10_cc486e491a2c499f9fd2aad2b02c6ccb.png b/assets/images/np10_cc486e491a2c499f9fd2aad2b02c6ccb.png
new file mode 100644
index 0000000000000000000000000000000000000000..11c1380556aceb47e762aeefedf7fe9acc8aa63d
--- /dev/null
+++ b/assets/images/np10_cc486e491a2c499f9fd2aad2b02c6ccb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f315756c12fc33dd6d1eb42a0331e5af73bdcc72bc585067d7cb60ebbfc84dd
+size 405383
diff --git a/assets/images/np11_0603aac4bbf241c383115f9da3e38d80.png b/assets/images/np11_0603aac4bbf241c383115f9da3e38d80.png
new file mode 100644
index 0000000000000000000000000000000000000000..66b877550b7afbc9af5ea150999e5ac5f73d6c86
--- /dev/null
+++ b/assets/images/np11_0603aac4bbf241c383115f9da3e38d80.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5afaae6860da9e3ae04d9ce2e04603ecf48ef43c9e082fe0bac9eaa4c546cde0
+size 679688
diff --git a/assets/images/np11_fa2a78dddbbf44589ebe1a315c5077e7.png b/assets/images/np11_fa2a78dddbbf44589ebe1a315c5077e7.png
new file mode 100644
index 0000000000000000000000000000000000000000..65c2ef0250af1cafe46af75ae4615930b597f679
--- /dev/null
+++ b/assets/images/np11_fa2a78dddbbf44589ebe1a315c5077e7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bde7a44b23d22c34a196d6d873e1edccc95595107e135e028e1decb75a74b649
+size 202970
diff --git a/assets/images/np12_6f9431c68ffa44d0b2e7635d646c4000.png b/assets/images/np12_6f9431c68ffa44d0b2e7635d646c4000.png
new file mode 100644
index 0000000000000000000000000000000000000000..fc702dfaf3953815cefdc6171b716f2dfc3a0026
--- /dev/null
+++ b/assets/images/np12_6f9431c68ffa44d0b2e7635d646c4000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08d5564bee676697023c387e32f014ead4ddd1d4b3ff6b246bfc2896fef24a54
+size 1233255
diff --git a/assets/images/np12_8a68419f5d454a28a7bc36da4395be60.png b/assets/images/np12_8a68419f5d454a28a7bc36da4395be60.png
new file mode 100644
index 0000000000000000000000000000000000000000..cb3020192a1f83396254acaf1fc462e3ca2b229e
--- /dev/null
+++ b/assets/images/np12_8a68419f5d454a28a7bc36da4395be60.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dcc0e02333cb29bc743b5d1e8223f1a5064a3e674cf939e641490b16d0ae80c7
+size 1252988
diff --git a/assets/images/np13_39c0fa16ed324b54a605dcdbcd80797c.png b/assets/images/np13_39c0fa16ed324b54a605dcdbcd80797c.png
new file mode 100644
index 0000000000000000000000000000000000000000..59b80b26b256953136c242fce229b43ca5b7e855
--- /dev/null
+++ b/assets/images/np13_39c0fa16ed324b54a605dcdbcd80797c.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e0647e48fb9255eb08022b414f5bf3102b1d9f5d576c5a4bf7c779b98fba601
+size 437956
diff --git a/assets/images/np13_b07b4d858z.png b/assets/images/np13_b07b4d858z.png
new file mode 100644
index 0000000000000000000000000000000000000000..0919102d67b6667e80a7cf9e6c79847c51640a6d
--- /dev/null
+++ b/assets/images/np13_b07b4d858z.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cad406ed0f5f4b53b05564a3e1bf2641b81f0e6b249732012b724593dbc0cb57
+size 277280
diff --git a/assets/images/np13_d8258903a23e4a368ede8fd584919b68.png b/assets/images/np13_d8258903a23e4a368ede8fd584919b68.png
new file mode 100644
index 0000000000000000000000000000000000000000..95a995a726983bf2caefa0587c8b99c085321ea8
--- /dev/null
+++ b/assets/images/np13_d8258903a23e4a368ede8fd584919b68.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46e60a55e2a51cac69751bb74c44dc901aca067664a2e01922b90e6169801c09
+size 767501
diff --git a/assets/images/np14_2faf0a30b1614a9ba029def4e63270d0.png b/assets/images/np14_2faf0a30b1614a9ba029def4e63270d0.png
new file mode 100644
index 0000000000000000000000000000000000000000..75ec63312a18f684760f4c52af405813bbcf8fe7
--- /dev/null
+++ b/assets/images/np14_2faf0a30b1614a9ba029def4e63270d0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:957d00379e6f789ae8d64d69e06e2c53f13f027f664a8f1282a84e6887997f67
+size 610765
diff --git a/assets/images/np14_b8f46cf7daca419a87ac8d131bad056f.png b/assets/images/np14_b8f46cf7daca419a87ac8d131bad056f.png
new file mode 100644
index 0000000000000000000000000000000000000000..a4b622a161e710916681e68785adc5428e9c62ac
--- /dev/null
+++ b/assets/images/np14_b8f46cf7daca419a87ac8d131bad056f.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:acfe00f60bbc3c49d707c0d36ae04f25646627b17aaaa3337b10416474b88085
+size 772583
diff --git a/assets/images/np15_56f1ef2b84514e5e8c11a42aaab6cda7.png b/assets/images/np15_56f1ef2b84514e5e8c11a42aaab6cda7.png
new file mode 100644
index 0000000000000000000000000000000000000000..0c88e142c3c707f424e77ff43e6769227cfa7f51
--- /dev/null
+++ b/assets/images/np15_56f1ef2b84514e5e8c11a42aaab6cda7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e635130771724202984be320ae64dd674eaf63fdb4801713ba4f907a139cea8b
+size 336772
diff --git a/assets/images/np15_b07k7k8yf9.png b/assets/images/np15_b07k7k8yf9.png
new file mode 100644
index 0000000000000000000000000000000000000000..b566e446fb0f737a2e51a2b140cc650044cf4d6c
Binary files /dev/null and b/assets/images/np15_b07k7k8yf9.png differ
diff --git a/assets/images/np2_2b7c6f4109fa4f84919bd30c367c663b.png b/assets/images/np2_2b7c6f4109fa4f84919bd30c367c663b.png
new file mode 100644
index 0000000000000000000000000000000000000000..d873f0973e7e462ab9deb3f38777fde98fc4dd5f
--- /dev/null
+++ b/assets/images/np2_2b7c6f4109fa4f84919bd30c367c663b.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4be0f7ad492ff3a309146c8d2a8d677b0226f07fb04a4afb524da0ff4478095
+size 389018
diff --git a/assets/images/np2_3cd2cdea87f34bd7acb096f5178b3760.png b/assets/images/np2_3cd2cdea87f34bd7acb096f5178b3760.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a1073a08e3a6d9e84bc1bc4e3e47eda57938627
--- /dev/null
+++ b/assets/images/np2_3cd2cdea87f34bd7acb096f5178b3760.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30a8575bacaa7eb12637487109e8cd1f619b007a67f17b3a880cb31ded96d968
+size 232387
diff --git a/assets/images/np3_2f6ab901c5a84ed6bbdf85a67b22a2ee.png b/assets/images/np3_2f6ab901c5a84ed6bbdf85a67b22a2ee.png
new file mode 100644
index 0000000000000000000000000000000000000000..a689d0d619905c80f78d00ee11d6f8b500a34676
--- /dev/null
+++ b/assets/images/np3_2f6ab901c5a84ed6bbdf85a67b22a2ee.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:119932ee0164ad7df229b73e54573fbbe774ae3d5f789f6073e52e20a283975b
+size 376376
diff --git a/assets/images/np3_5bbc755b07924fcba6ccd702191f90df.png b/assets/images/np3_5bbc755b07924fcba6ccd702191f90df.png
new file mode 100644
index 0000000000000000000000000000000000000000..a86195c51f68c67df461b484c6f57371067418ef
--- /dev/null
+++ b/assets/images/np3_5bbc755b07924fcba6ccd702191f90df.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76caf3706390576c82da17a0ca487fc93d33150a5be8d4457961e4d314ceb237
+size 269552
diff --git a/assets/images/np3_7ba4129133dc46d0818930adb53c0315.png b/assets/images/np3_7ba4129133dc46d0818930adb53c0315.png
new file mode 100644
index 0000000000000000000000000000000000000000..c5d9528610af86a716a1dd2496a73edf01f0cd01
--- /dev/null
+++ b/assets/images/np3_7ba4129133dc46d0818930adb53c0315.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:202da0673a23e227e03f1fe715e7ffd728f4732140fff14607982e4142687922
+size 1173643
diff --git a/assets/images/np3_7c00a1667b2c4413b4dcb24500f05892.png b/assets/images/np3_7c00a1667b2c4413b4dcb24500f05892.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc9bd523c0fed3fc35b77f9f2a5e6f74d45a1b85
--- /dev/null
+++ b/assets/images/np3_7c00a1667b2c4413b4dcb24500f05892.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce2c8ef23f28b23b613dfc5f7ed588d4e0149c855f6e6a854b251297bf6fdff6
+size 307252
diff --git a/assets/images/np3_be2e5af368e748b88871ed5776e951bb.png b/assets/images/np3_be2e5af368e748b88871ed5776e951bb.png
new file mode 100644
index 0000000000000000000000000000000000000000..b168a3f85c8c5324b64e398d925bcc4e7480749b
--- /dev/null
+++ b/assets/images/np3_be2e5af368e748b88871ed5776e951bb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:018e5b0b0c3b67653401c26cadd0fa9d79e36a143d8d69735ccf1419295ba872
+size 245093
diff --git a/assets/images/np3_c25308e36b3a4c6d9745e01fb34a93a1.png b/assets/images/np3_c25308e36b3a4c6d9745e01fb34a93a1.png
new file mode 100644
index 0000000000000000000000000000000000000000..85fb3248d215f317f0eb552822b36b1c9bd86b5e
--- /dev/null
+++ b/assets/images/np3_c25308e36b3a4c6d9745e01fb34a93a1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9922331dd7a7a3be680ef0aab4a3ea53921429db99922fae64aa20bece6226ff
+size 800553
diff --git a/assets/images/np3_d2e1aa2d2b4b424282d3b643fc8169cf.png b/assets/images/np3_d2e1aa2d2b4b424282d3b643fc8169cf.png
new file mode 100644
index 0000000000000000000000000000000000000000..35738897334ec72b29ae7b98b9a584b4a56da83d
--- /dev/null
+++ b/assets/images/np3_d2e1aa2d2b4b424282d3b643fc8169cf.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b0d7069bf3f54ca39f78c92671d8ae1cb2e7ef85ef81ac5605f237f9d501db3
+size 388444
diff --git a/assets/images/np4_2444ea17f3a448b1bb7e2a74b276f015.png b/assets/images/np4_2444ea17f3a448b1bb7e2a74b276f015.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d8880725220ec8684bea7e2571e73a4788b3583
--- /dev/null
+++ b/assets/images/np4_2444ea17f3a448b1bb7e2a74b276f015.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c7b7a8863912a15b7fbd59272d5ec07caa25dd30483cb07a0a10b01e30301dda
+size 290712
diff --git a/assets/images/np4_4d6e4d3dd9194633ac86285ada1017ad.png b/assets/images/np4_4d6e4d3dd9194633ac86285ada1017ad.png
new file mode 100644
index 0000000000000000000000000000000000000000..1b0859b9af6f4b101e4001656ded1a5ad433237c
--- /dev/null
+++ b/assets/images/np4_4d6e4d3dd9194633ac86285ada1017ad.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ecd3a9ea69a2705e49c9f82b55c6cd1cb4034693ca9867ee9bb9eef336106ae
+size 762804
diff --git a/assets/images/np4_7bd5d25aa77b4fb18e780d7a4c97d342.png b/assets/images/np4_7bd5d25aa77b4fb18e780d7a4c97d342.png
new file mode 100644
index 0000000000000000000000000000000000000000..645ec1ae58b5e19fdd0e4baf6373946c7e0a2974
--- /dev/null
+++ b/assets/images/np4_7bd5d25aa77b4fb18e780d7a4c97d342.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:143b11944e54a87177bb4c4722ef6918855f4845864fd9a65b02fc8de9592462
+size 821337
diff --git a/assets/images/np4_80e586911397457aa9245eed1eb03abe.png b/assets/images/np4_80e586911397457aa9245eed1eb03abe.png
new file mode 100644
index 0000000000000000000000000000000000000000..fbb3bafc188d66d5433c662e4a6be4b2c8434077
--- /dev/null
+++ b/assets/images/np4_80e586911397457aa9245eed1eb03abe.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe08ad9a702ff2a34ee0cf1f218975c98d9adddf295412b413b2ddaafa94294e
+size 475737
diff --git a/assets/images/np4_ca27601df1384a7aa152baacfb072306.png b/assets/images/np4_ca27601df1384a7aa152baacfb072306.png
new file mode 100644
index 0000000000000000000000000000000000000000..c8ebb5964b65be19654aa199186424d761922df4
--- /dev/null
+++ b/assets/images/np4_ca27601df1384a7aa152baacfb072306.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:82a7cda5209ca501fa989123605e794f62c89c0a7dbb19158f376e4f1442a9d9
+size 201013
diff --git a/assets/images/np5_23ae06bb5cf84e13ae973721fa5f5625.png b/assets/images/np5_23ae06bb5cf84e13ae973721fa5f5625.png
new file mode 100644
index 0000000000000000000000000000000000000000..cf2fcb3dd8dcea31c2b01df3c240e2927c7730c8
--- /dev/null
+++ b/assets/images/np5_23ae06bb5cf84e13ae973721fa5f5625.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8defc96987615208f87d31a7a527533491019d10e71b117cf1df564c4648352a
+size 336412
diff --git a/assets/images/np5_b81f29e567ea4db48014f89c9079e403.png b/assets/images/np5_b81f29e567ea4db48014f89c9079e403.png
new file mode 100644
index 0000000000000000000000000000000000000000..20bd9a237b21e7cca1838155a39d8cb46e514fd5
--- /dev/null
+++ b/assets/images/np5_b81f29e567ea4db48014f89c9079e403.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d45e657d14c93844ef02f73e9ecc9d8997d49c7b5490e5de95e5fff69620d235
+size 342733
diff --git a/assets/images/np6_b7295d1f9c484a84a53f7ba62ead149e.png b/assets/images/np6_b7295d1f9c484a84a53f7ba62ead149e.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a4aa92f076db9c84b75b0b0e6bae6dd8992922c
--- /dev/null
+++ b/assets/images/np6_b7295d1f9c484a84a53f7ba62ead149e.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:48f5fe3099b3832a74f4190c780cee2865442a2020edfc419f2fe025abec4ff0
+size 1089146
diff --git a/assets/images/np6_bd2716459c774dc48fc793f1b76511e8.png b/assets/images/np6_bd2716459c774dc48fc793f1b76511e8.png
new file mode 100644
index 0000000000000000000000000000000000000000..75c478568063506a9d3aa87c39f8323593e79ad8
--- /dev/null
+++ b/assets/images/np6_bd2716459c774dc48fc793f1b76511e8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e5d8f25ecaecdb87976ba6b0d24a06a9877937c06741921aba93e852d53579c
+size 256717
diff --git a/assets/images/np7_1749ad163235411295ed3342d024f1ac.png b/assets/images/np7_1749ad163235411295ed3342d024f1ac.png
new file mode 100644
index 0000000000000000000000000000000000000000..2de92cb7ec537a9190a29535328bfd23f8023665
--- /dev/null
+++ b/assets/images/np7_1749ad163235411295ed3342d024f1ac.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e9e5f1eb05bae290546a01b49d58da903bd119a3d0a3fac6aa9bd7b1d1c4c08
+size 311775
diff --git a/assets/images/np7_1c004909dedb4ebe8db69b4d7b077434.png b/assets/images/np7_1c004909dedb4ebe8db69b4d7b077434.png
new file mode 100644
index 0000000000000000000000000000000000000000..2687cc38f0bce0264b15fb0618761eb226de6e64
--- /dev/null
+++ b/assets/images/np7_1c004909dedb4ebe8db69b4d7b077434.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:311ca797fd35396514da6bd7cdeba43e7cb6b8e98c667805c867a882461dde6f
+size 492676
diff --git a/assets/images/np7_8d3e568f059244e19f3b5f7e789cccb2.png b/assets/images/np7_8d3e568f059244e19f3b5f7e789cccb2.png
new file mode 100644
index 0000000000000000000000000000000000000000..d8746a37bc16dcc1a4872981c82f6a887fad0eea
--- /dev/null
+++ b/assets/images/np7_8d3e568f059244e19f3b5f7e789cccb2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:394bc22d9a790d34834e6e1e86f862b3de6c7d007d5341210c1a3c7baa7d6709
+size 104446
diff --git a/assets/images/np8_a7d7c17eadc54aa5a0094415025463ff.png b/assets/images/np8_a7d7c17eadc54aa5a0094415025463ff.png
new file mode 100644
index 0000000000000000000000000000000000000000..dbb6c31bb680c9c97ddc53152a342ef462ca55e8
--- /dev/null
+++ b/assets/images/np8_a7d7c17eadc54aa5a0094415025463ff.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6074adbed8cccb9c366a74d381b88efd67bfa0c7ca4efd640b4de086640ea4c
+size 104937
diff --git a/assets/images/np8_d84dc4fc47614a4687a774152b343ddd.png b/assets/images/np8_d84dc4fc47614a4687a774152b343ddd.png
new file mode 100644
index 0000000000000000000000000000000000000000..732b6c258c6652d1eef117b1fdcef7f14f042583
--- /dev/null
+++ b/assets/images/np8_d84dc4fc47614a4687a774152b343ddd.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1c8cf67994b41050f6d4f10304976f51c7dc76370acdddcdd2d3315ed7742b1
+size 361523
diff --git a/assets/images/np9_5cc2b8234ff04a2aa29d5f0a0393ed0c.png b/assets/images/np9_5cc2b8234ff04a2aa29d5f0a0393ed0c.png
new file mode 100644
index 0000000000000000000000000000000000000000..c7321619dc130843a3ad7be72044be49ac4a3274
--- /dev/null
+++ b/assets/images/np9_5cc2b8234ff04a2aa29d5f0a0393ed0c.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:208ebdc84b344eb863f93a254e0e4bc8a9e4499db5026569f145d79505a9fc8c
+size 123938
diff --git a/assets/images/np9_8031fc7690e640038ff6a6766e97f19d.png b/assets/images/np9_8031fc7690e640038ff6a6766e97f19d.png
new file mode 100644
index 0000000000000000000000000000000000000000..a592af084862e315940cc2538cf2902646bc73ed
--- /dev/null
+++ b/assets/images/np9_8031fc7690e640038ff6a6766e97f19d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c286e0c1e9a945fba3a49bbf6a121b820b00c248215978fdef5c588e2a6cd447
+size 614414
diff --git a/assets/images/np9_94b4b6d936214b3f988bd4b520da4f53.png b/assets/images/np9_94b4b6d936214b3f988bd4b520da4f53.png
new file mode 100644
index 0000000000000000000000000000000000000000..de7eb6c03bab85f4d742c9f036a799ab1c4670e4
--- /dev/null
+++ b/assets/images/np9_94b4b6d936214b3f988bd4b520da4f53.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a020950d84537cd0d4d70cc86fc9c1b55f0df9d2e9765cfd2172502ba29ff4b
+size 335052
diff --git a/assets/images/np9_dd0ec139989c430fb6572fa024ae1c20.png b/assets/images/np9_dd0ec139989c430fb6572fa024ae1c20.png
new file mode 100644
index 0000000000000000000000000000000000000000..93f0bca6d8c27f8d2c9a3d7092c3f9dd64789d6a
--- /dev/null
+++ b/assets/images/np9_dd0ec139989c430fb6572fa024ae1c20.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:070eecaafa27714b2cb5696e082d4521898c1c2f909658cc5fa65dd0be493e47
+size 185856
diff --git a/assets/objects/scissors.glb b/assets/objects/scissors.glb
new file mode 100644
index 0000000000000000000000000000000000000000..991041b6767e8b36baedfbed3ae33518e5b14ec4
--- /dev/null
+++ b/assets/objects/scissors.glb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac7b9d2218cb6db57c90581bf43b10ccf641bdd0f5eeda9de38770341f01bcef
+size 9250644
diff --git a/assets/objects/sword.glb b/assets/objects/sword.glb
new file mode 100644
index 0000000000000000000000000000000000000000..3df980db2e379834cac8d0602279602b0e03f98c
--- /dev/null
+++ b/assets/objects/sword.glb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:731fdcff8f129f2ebdb4c5b8967091e0d2041c7ce5064d5061ef628bcdb7e0f2
+size 413064
diff --git a/assets/robot.gif b/assets/robot.gif
new file mode 100644
index 0000000000000000000000000000000000000000..70368d4a70755ce8cf651d4d8f00a60c5bf66643
--- /dev/null
+++ b/assets/robot.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2a72d5f23f092f08ff3f5594af12c24faa9e679f7e6cdbc65ebd96f5837c4ef
+size 2462693
diff --git a/assets/teaser.png b/assets/teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..885822c923f788fa0f839a37f8ba98fe39b71936
--- /dev/null
+++ b/assets/teaser.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:823430086ba3135328503249cdda694e82db4f3a36c60f3c180ecea6ec63d57a
+size 3421937
diff --git a/configs/mp16_nt1024.yaml b/configs/mp16_nt1024.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d44868cb0854e54f751a8cca3f46a6e947f2a8c
--- /dev/null
+++ b/configs/mp16_nt1024.yaml
@@ -0,0 +1,77 @@
+model:
+ pretrained_model_name_or_path: 'pretrained_weights/TripoSG'
+ vae:
+ num_tokens: 1024
+ transformer:
+ enable_local_cross_attn: true
+ global_attn_block_ids: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
+ global_attn_block_id_range: null # The average should be 10 for unet-skipping
+
+
+dataset:
+ config:
+ - 'datasets/object_part_configs.json' # Modify this path if you use your own dataset
+ training_ratio: 0.9
+ min_num_parts: 1
+ max_num_parts: 16
+ max_iou_mean: 0.2
+ max_iou_max: 0.2
+ shuffle_parts: true
+ object_ratio: 0.3
+ rotating_ratio: 0.2
+ ratating_degree: 10
+
+optimizer:
+ name: "adamw"
+ lr: 5e-5
+ betas:
+ - 0.9
+ - 0.999
+ weight_decay: 0.01
+ eps: 1.e-8
+
+lr_scheduler:
+ name: "constant_warmup"
+ num_warmup_steps: 1000
+
+train:
+ batch_size_per_gpu: 32
+ epochs: 10
+ grad_checkpoint: true
+ weighting_scheme: "logit_normal"
+ logit_mean: 0.0
+ logit_std: 1.0
+ mode_scale: 1.29
+ cfg_dropout_prob: 0.1
+ training_objective: "-v"
+ log_freq: 1
+ early_eval_freq: 500
+ early_eval: 1000
+ eval_freq: 1000
+ save_freq: 2000
+ eval_freq_epoch: 5
+ save_freq_epoch: 10
+ ema_kwargs:
+ decay: 0.9999
+ use_ema_warmup: true
+ inv_gamma: 1.
+ power: 0.75
+
+val:
+ batch_size_per_gpu: 1
+ nrow: 4
+ min_num_parts: 2
+ max_num_parts: 8
+ num_inference_steps: 50
+ max_num_expanded_coords: 1e8
+ use_flash_decoder: false
+ rendering:
+ radius: 4.0
+ num_views: 36
+ fps: 18
+ metric:
+ cd_num_samples: 204800
+ cd_metric: "l2"
+ f1_score_threshold: 0.1
+ default_cd: 1e6
+ default_f1: 0.0
\ No newline at end of file
diff --git a/configs/mp16_nt512.yaml b/configs/mp16_nt512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f5773d3c4da085c4ee24f0d35dd55835662651d4
--- /dev/null
+++ b/configs/mp16_nt512.yaml
@@ -0,0 +1,77 @@
+model:
+ pretrained_model_name_or_path: 'pretrained_weights/TripoSG'
+ vae:
+ num_tokens: 512
+ transformer:
+ enable_local_cross_attn: true
+ global_attn_block_ids: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
+ global_attn_block_id_range: null # The average should be 10 for unet-skipping
+
+
+dataset:
+ config:
+ - 'datasets/object_part_configs.json' # Modify this path if you use your own dataset
+ training_ratio: 0.9
+ min_num_parts: 1
+ max_num_parts: 16
+ max_iou_mean: 0.5
+ max_iou_max: 0.5
+ shuffle_parts: true
+ object_ratio: 0.3
+ rotating_ratio: 0.2
+ ratating_degree: 10
+
+optimizer:
+ name: "adamw"
+ lr: 5e-5
+ betas:
+ - 0.9
+ - 0.999
+ weight_decay: 0.01
+ eps: 1.e-8
+
+lr_scheduler:
+ name: "constant_warmup"
+ num_warmup_steps: 1000
+
+train:
+ batch_size_per_gpu: 32
+ epochs: 10
+ grad_checkpoint: true
+ weighting_scheme: "logit_normal"
+ logit_mean: 0.0
+ logit_std: 1.0
+ mode_scale: 1.29
+ cfg_dropout_prob: 0.1
+ training_objective: "-v"
+ log_freq: 1
+ early_eval_freq: 500
+ early_eval: 1000
+ eval_freq: 1000
+ save_freq: 2000
+ eval_freq_epoch: 5
+ save_freq_epoch: 10
+ ema_kwargs:
+ decay: 0.9999
+ use_ema_warmup: true
+ inv_gamma: 1.
+ power: 0.75
+
+val:
+ batch_size_per_gpu: 1
+ nrow: 4
+ min_num_parts: 2
+ max_num_parts: 8
+ num_inference_steps: 50
+ max_num_expanded_coords: 1e8
+ use_flash_decoder: false
+ rendering:
+ radius: 4.0
+ num_views: 36
+ fps: 18
+ metric:
+ cd_num_samples: 204800
+ cd_metric: "l2"
+ f1_score_threshold: 0.1
+ default_cd: 1e6
+ default_f1: 0.0
\ No newline at end of file
diff --git a/configs/mp8_nt512.yaml b/configs/mp8_nt512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..40547f122defc9ced4fb8ffebd012939a36a6701
--- /dev/null
+++ b/configs/mp8_nt512.yaml
@@ -0,0 +1,77 @@
+model:
+ pretrained_model_name_or_path: 'pretrained_weights/TripoSG'
+ vae:
+ num_tokens: 512
+ transformer:
+ enable_local_cross_attn: true
+ global_attn_block_ids: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
+ global_attn_block_id_range: null # The average should be 10 for unet-skipping
+
+
+dataset:
+ config:
+ - 'datasets/object_part_configs.json' # Modify this path if you use your own dataset
+ training_ratio: 0.9
+ min_num_parts: 1
+ max_num_parts: 8
+ max_iou_mean: 0.5
+ max_iou_max: 0.5
+ shuffle_parts: true
+ object_ratio: 0.3
+ rotating_ratio: 0.2
+ ratating_degree: 10
+
+optimizer:
+ name: "adamw"
+ lr: 1e-4
+ betas:
+ - 0.9
+ - 0.999
+ weight_decay: 0.01
+ eps: 1.e-8
+
+lr_scheduler:
+ name: "constant_warmup"
+ num_warmup_steps: 1000
+
+train:
+ batch_size_per_gpu: 32
+ epochs: 10
+ grad_checkpoint: true
+ weighting_scheme: "logit_normal"
+ logit_mean: 0.0
+ logit_std: 1.0
+ mode_scale: 1.29
+ cfg_dropout_prob: 0.1
+ training_objective: "-v"
+ log_freq: 1
+ early_eval_freq: 500
+ early_eval: 1000
+ eval_freq: 1000
+ save_freq: 2000
+ eval_freq_epoch: 5
+ save_freq_epoch: 10
+ ema_kwargs:
+ decay: 0.9999
+ use_ema_warmup: true
+ inv_gamma: 1.
+ power: 0.75
+
+val:
+ batch_size_per_gpu: 1
+ nrow: 4
+ min_num_parts: 2
+ max_num_parts: 8
+ num_inference_steps: 50
+ max_num_expanded_coords: 1e8
+ use_flash_decoder: false
+ rendering:
+ radius: 4.0
+ num_views: 36
+ fps: 18
+ metric:
+ cd_num_samples: 204800
+ cd_metric: "l2"
+ f1_score_threshold: 0.1
+ default_cd: 1e6
+ default_f1: 0.0
\ No newline at end of file
diff --git a/datasets/README.md b/datasets/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cf5d6c19e62ae0d82786d2ea971c6c89b11496a6
--- /dev/null
+++ b/datasets/README.md
@@ -0,0 +1,64 @@
+# Dataset Preparation
+We provide the data preprocessing pipeline for PartCrafter. By following the instructions below, you can generate training data from raw GLB files. We are considering releasing the preprocessed dataset, but it may take some time before it becomes available.
+
+## Download Raw Data
+Our final model uses a subset of [Objaverse](https://huggingface.co/datasets/allenai/objaverse) provided by [LGM](https://github.com/ashawkey/objaverse_filter) and the [Amazon Berkeley Objects (ABO) Dataset](https://amazon-berkeley-objects.s3.amazonaws.com/index.html). Please download the raw GLB files according to their instructions. You can also use other sources of data.
+
+## Data Preprocess
+We provide several scripts to preprocess the raw GLB files [here](./preprocess/). These scripts are minimal implementations that illustrate the whole preprocessing pipeline on a single 3D object.
+
+1. Sample points from the mesh surface
+```
+python datasets/preprocess/mesh_to_point.py --input assets/objects/scissors.glb --output preprocessed_data
+```
+
+2. Render images
+```
+python datasets/preprocess/render.py --input assets/objects/scissors.glb --output preprocessed_data
+```
+
+3. Remove the background from the rendered images and resize them to 90%
+```
+python datasets/preprocess/rmbg.py --input preprocessed_data/scissors/rendering.png --output preprocessed_data
+```
+
+4. (Optional) Calculate IoU
+```
+python datasets/preprocess/calculate_iou.py --input assets/objects/scissors.glb --output preprocessed_data
+```
+After preprocessing, you can generate a dataset configuration file with your own data paths, following the example configuration file (see the sketch below).
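+
+For example, a minimal sketch of how such a file could be assembled. The `make_entry` helper below is hypothetical (not part of the released scripts); the paths follow the single-object output layout above, and the IoU values are placeholders unless you run the optional IoU step.
+```
+import json
+import os
+
+def make_entry(name, mesh_path, num_parts, out_dir="preprocessed_data"):
+    # Paths follow the layout produced by the single-object scripts above.
+    # iou_mean / iou_max are placeholders; fill them in if you computed IoU.
+    return {
+        "mesh_path": mesh_path,
+        "surface_path": os.path.join(out_dir, name, "points.npy"),
+        "image_path": os.path.join(out_dir, name, "rendering_rmbg.png"),
+        "num_parts": num_parts,
+        "iou_mean": 0.0,
+        "iou_max": 0.0,
+        "valid": True,
+    }
+
+entries = [make_entry("scissors", "assets/objects/scissors.glb", num_parts=2)]
+os.makedirs("preprocessed_data", exist_ok=True)
+with open(os.path.join("preprocessed_data", "object_part_configs.json"), "w") as f:
+    json.dump(entries, f, indent=2)
+```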
+
+To preprocess a folder of meshes, run
+```
+python datasets/preprocess/preprocess.py --input assets/objects --output preprocessed_data
+```
+This will also generate a configuration file in `./preprocessed_data/object_part_configs.json`.
+
+## Dataset Configuration
+The training code requires a specific dataset configuration format. We provide an example configuration [here](example_configs.json), which you can use as a template for your own dataset. A minimal valid configuration file looks like this:
+
+```
+[
+ {
+ "mesh_path": "/path/to/object.glb",
+ "surface_path": "/path/to/object.npy",
+ "image_path": "/path/to/object.png",
+ "num_parts": 4,
+ "iou_mean": 0.5,
+ "iou_max": 0.9,
+ "valid": true
+ },
+ {
+ ...
+ },
+ ...
+]
+```
+Explanation of each field:
+- `mesh_path`: The path to the GLB file of the object.
+- `surface_path`: The path to the npy file of the object surface points.
+- `image_path`: The path to the rendered image of the object (after removing background).
+- `num_parts`: The number of parts of the object.
+- `iou_mean`: The mean IoU of the object parts.
+- `iou_max`: The max IoU of the object parts.
+- `valid`: Whether the object is valid. If set to `false`, the object is filtered out during training (see the loading sketch below).
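+
+The snippet below is a minimal loading sketch, not part of the training code: it assumes the configuration format above and simply drops invalid entries. The `max_num_parts` cutoff mirrors the dataset configs but is illustrative here.
+```
+import json
+
+def load_configs(path, max_num_parts=16):
+    # Read a dataset configuration file and keep only usable entries.
+    with open(path) as f:
+        entries = json.load(f)
+    return [
+        e for e in entries
+        if e.get("valid", False) and e.get("num_parts", 0) <= max_num_parts
+    ]
+
+configs = load_configs("datasets/object_part_configs.json")
+print(f"{len(configs)} usable objects")
+```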
\ No newline at end of file
diff --git a/datasets/example_configs.json b/datasets/example_configs.json
new file mode 100644
index 0000000000000000000000000000000000000000..2577913f6c68b1eb431ed48a466efa2e45988d46
--- /dev/null
+++ b/datasets/example_configs.json
@@ -0,0 +1,38 @@
+[
+ {
+ "dataset": "objaverse",
+ "file": "0007a7c8fcb44074b20fa4e14b8730a6.glb",
+ "folder": "000",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "/datasets/objaverse/000/0007a7c8fcb44074b20fa4e14b8730a6.glb",
+ "surface_path": "/datasets/objaverse_point/000/0007a7c8fcb44074b20fa4e14b8730a6.npy",
+ "image_path": "/datasets/objaverse_rendering_rmbg/000/0007a7c8fcb44074b20fa4e14b8730a6.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "dataset": "objaverse",
+ "file": "e05d63a78f2c44c5a64fbc419998e603.glb",
+ "folder": "001",
+ "num_parts": 7,
+ "valid": true,
+ "mesh_path": "/datasets/objaverse/001/e05d63a78f2c44c5a64fbc419998e603.glb",
+ "surface_path": "/datasets/objaverse_point/001/e05d63a78f2c44c5a64fbc419998e603.npy",
+ "image_path": "/datasets/objaverse_rendering_rmbg/001/e05d63a78f2c44c5a64fbc419998e603.png",
+ "iou_mean": 0.0005096137625985897,
+ "iou_max": 0.0022603978300180833
+ },
+ {
+ "dataset": "objaverse",
+ "file": "e0a9d033f27043339eb724ec5b7f5fdc.glb",
+ "folder": "002",
+ "num_parts": 3,
+ "valid": true,
+ "mesh_path": "/datasets/objaverse/002/e0a9d033f27043339eb724ec5b7f5fdc.glb",
+ "surface_path": "/datasets/objaverse_point/002/e0a9d033f27043339eb724ec5b7f5fdc.npy",
+ "image_path": "/datasets/objaverse_rendering_rmbg/002/e0a9d033f27043339eb724ec5b7f5fdc.png",
+ "iou_mean": 0.0683491062039958,
+ "iou_max": 0.20504731861198738
+ }
+]
\ No newline at end of file
diff --git a/datasets/object_part_configs.json b/datasets/object_part_configs.json
new file mode 100644
index 0000000000000000000000000000000000000000..e5701ff130567561ca4971d310f0dfba662b61fb
--- /dev/null
+++ b/datasets/object_part_configs.json
@@ -0,0 +1,2002 @@
+[
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "scissors.glb",
+ "num_parts": 2,
+ "valid": true,
+ "mesh_path": "assets/objects/scissors.glb",
+ "surface_path": "preprocessed_data/scissors/points.npy",
+ "image_path": "preprocessed_data/scissors/rendering_rmbg.png",
+ "iou_mean": 0.021141649048625793,
+ "iou_max": 0.021141649048625793
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ },
+ {
+ "file": "sword.glb",
+ "num_parts": 1,
+ "valid": true,
+ "mesh_path": "assets/objects/sword.glb",
+ "surface_path": "preprocessed_data/sword/points.npy",
+ "image_path": "preprocessed_data/sword/rendering_rmbg.png",
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ }
+]
\ No newline at end of file
diff --git a/datasets/preprocess/calculate_iou.py b/datasets/preprocess/calculate_iou.py
new file mode 100644
index 0000000000000000000000000000000000000000..76e13366cb529b31d7cee3ee4360a23662df2413
--- /dev/null
+++ b/datasets/preprocess/calculate_iou.py
@@ -0,0 +1,42 @@
+import os
+import trimesh
+import numpy as np
+import argparse
+import json
+
+from src.utils.data_utils import normalize_mesh
+from src.utils.metric_utils import compute_IoU_for_scene
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input', type=str, default='assets/objects/scissors.glb')
+ parser.add_argument('--output', type=str, default='preprocessed_data')
+ args = parser.parse_args()
+
+ input_path = args.input
+ output_path = args.output
+
+ assert os.path.exists(input_path), f'{input_path} does not exist'
+
+ mesh_name = os.path.basename(input_path).split('.')[0]
+ output_path = os.path.join(output_path, mesh_name)
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+
+ config = {
+ 'iou_mean': 0.0,
+ 'iou_max': 0.0,
+ 'iou_list': [],
+ }
+ mesh = normalize_mesh(trimesh.load(input_path, process=False))
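+    # Compute the scene's IoU values and record their mean and max in the config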
+    try:
+        iou_list = compute_IoU_for_scene(mesh, return_type='iou_list')
+        # Cast to built-in floats so the config stays JSON-serializable
+        config['iou_list'] = [float(iou) for iou in iou_list]
+        config['iou_mean'] = float(np.mean(iou_list))
+        config['iou_max'] = float(np.max(iou_list))
+    except Exception:
+        # Fall back to zero IoU if the computation fails
+        config['iou_list'] = []
+        config['iou_mean'] = 0.0
+        config['iou_max'] = 0.0
+
+    with open(os.path.join(output_path, 'iou.json'), 'w') as f:
+        json.dump(config, f, indent=4)
\ No newline at end of file
diff --git a/datasets/preprocess/mesh_to_point.py b/datasets/preprocess/mesh_to_point.py
new file mode 100644
index 0000000000000000000000000000000000000000..273d0c423eb27fb1fb53a7212f7fa4bc7bbaa2ca
--- /dev/null
+++ b/datasets/preprocess/mesh_to_point.py
@@ -0,0 +1,52 @@
+import os
+import trimesh
+import numpy as np
+import argparse
+import json
+
+from src.utils.data_utils import scene_to_parts, mesh_to_surface, normalize_mesh
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input', type=str, default='assets/objects/scissors.glb')
+ parser.add_argument('--output', type=str, default='preprocessed_data')
+ args = parser.parse_args()
+
+ input_path = args.input
+ output_path = args.output
+
+ assert os.path.exists(input_path), f'{input_path} does not exist'
+
+ mesh_name = os.path.basename(input_path).split('.')[0]
+ output_path = os.path.join(output_path, mesh_name)
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+
+ config = {
+ "num_parts": 0,
+ }
+
+ # sample points from mesh surface
+ mesh = trimesh.load(input_path, process=False)
+ mesh = normalize_mesh(mesh)
+ config["num_parts"] = len(mesh.geometry)
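+    # Split into per-part point clouds only for scenes with 2-16 parts;
+    # the whole-object surface is always sampled below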
+ if config["num_parts"] > 1 and config["num_parts"] <= 16:
+ parts = scene_to_parts(
+ mesh,
+ return_type="point",
+ normalize=False
+ )
+ else:
+ parts = []
+ mesh = mesh.to_geometry()
+    object_surface = mesh_to_surface(mesh, return_dict=True)
+    data = {
+        "object": object_surface,
+        "parts": parts,
+    }
+    # save sampled surface points for the whole object and each part
+    np.save(os.path.join(output_path, 'points.npy'), data)
+
+ # save config
+ with open(os.path.join(output_path, 'num_parts.json'), 'w') as f:
+ json.dump(config, f, indent=4)
\ No newline at end of file
diff --git a/datasets/preprocess/preprocess.py b/datasets/preprocess/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a83c365eecd2dc0df36092cbaa3a4ba90694631
--- /dev/null
+++ b/datasets/preprocess/preprocess.py
@@ -0,0 +1,68 @@
+import os
+import json
+import argparse
+import time
+from tqdm import tqdm
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input', type=str, default='assets/objects')
+ parser.add_argument('--output', type=str, default='preprocessed_data')
+ args = parser.parse_args()
+
+ input_path = args.input
+ output_path = args.output
+
+ assert os.path.exists(input_path), f'{input_path} does not exist'
+
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+
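+    # Run the per-mesh preprocessing steps below for every file in the input folder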
+ for mesh_name in tqdm(os.listdir(input_path)):
+ mesh_path = os.path.join(input_path, mesh_name)
+ # 1. Sample points from mesh surface
+ os.system(f"python datasets/preprocess/mesh_to_point.py --input {mesh_path} --output {output_path}")
+ # 2. Render images
+ os.system(f"python datasets/preprocess/render.py --input {mesh_path} --output {output_path}")
+ # 3. Remove background for rendered images and resize to 90%
+ export_mesh_folder = os.path.join(output_path, mesh_name.replace('.glb', ''))
+ export_rendering_path = os.path.join(export_mesh_folder, 'rendering.png')
+ os.system(f"python datasets/preprocess/rmbg.py --input {export_rendering_path} --output {output_path}")
+ # 4. (Optional) Calculate IoU
+ os.system(f"python datasets/preprocess/calculate_iou.py --input {mesh_path} --output {output_path}")
+ time.sleep(1)
+
+ # generate configs
+ configs = []
+ for mesh_name in tqdm(os.listdir(input_path)):
+ mesh_path = os.path.join(output_path, mesh_name.replace('.glb', ''))
+ num_parts_path = os.path.join(mesh_path, 'num_parts.json')
+ surface_path = os.path.join(mesh_path, 'points.npy')
+ image_path = os.path.join(mesh_path, 'rendering_rmbg.png')
+ iou_path = os.path.join(mesh_path, 'iou.json')
+ config = {
+ "file": mesh_name,
+ "num_parts": 0,
+ "valid": False,
+ "mesh_path": os.path.join(input_path, mesh_name),
+ "surface_path": None,
+ "image_path": None,
+ "iou_mean": 0.0,
+ "iou_max": 0.0
+ }
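+        # An entry is marked valid only if every preprocessing output exists and parses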
+ try:
+ config["num_parts"] = json.load(open(num_parts_path))['num_parts']
+ iou_config = json.load(open(iou_path))
+ config['iou_mean'] = iou_config['iou_mean']
+ config['iou_max'] = iou_config['iou_max']
+ assert os.path.exists(surface_path)
+ config['surface_path'] = surface_path
+ assert os.path.exists(image_path)
+ config['image_path'] = image_path
+ config['valid'] = True
+ configs.append(config)
+        except Exception:
+            # Skip objects whose preprocessing outputs are missing or malformed
+            continue
+
+ configs_path = os.path.join(output_path, 'object_part_configs.json')
+    with open(configs_path, 'w') as f:
+        json.dump(configs, f, indent=4)
\ No newline at end of file
diff --git a/datasets/preprocess/render.py b/datasets/preprocess/render.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d9caf141d476b058f8af8a615e0bddaa3a6e024
--- /dev/null
+++ b/datasets/preprocess/render.py
@@ -0,0 +1,41 @@
+import os
+import trimesh
+import numpy as np
+import argparse
+import json
+
+from src.utils.data_utils import normalize_mesh
+from src.utils.render_utils import render_single_view
+
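+# Rendering settings: camera distance, output resolution, light intensity, and environment lights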
+RADIUS = 4
+IMAGE_SIZE = (2048, 2048)
+LIGHT_INTENSITY = 2.5
+NUM_ENV_LIGHTS = 36
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input', type=str, default='assets/objects/scissors.glb')
+ parser.add_argument('--output', type=str, default='preprocessed_data')
+ args = parser.parse_args()
+
+ input_path = args.input
+ output_path = args.output
+
+ assert os.path.exists(input_path), f'{input_path} does not exist'
+
+ mesh_name = os.path.basename(input_path).split('.')[0]
+ output_path = os.path.join(output_path, mesh_name)
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+
+ mesh = normalize_mesh(trimesh.load(input_path, process=False))
+ mesh = mesh.to_geometry()
+ image = render_single_view(
+ mesh,
+ radius=RADIUS,
+ image_size=IMAGE_SIZE,
+ light_intensity=LIGHT_INTENSITY,
+ num_env_lights=NUM_ENV_LIGHTS,
+ return_type='pil'
+ )
+    image.save(os.path.join(output_path, 'rendering.png'))
\ No newline at end of file
diff --git a/datasets/preprocess/rmbg.py b/datasets/preprocess/rmbg.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbbef5fc8cc2ea77733f2bbc695bf37b8dcf5fc1
--- /dev/null
+++ b/datasets/preprocess/rmbg.py
@@ -0,0 +1,33 @@
+import os
+import trimesh
+import numpy as np
+import argparse
+import json
+import torch
+from huggingface_hub import snapshot_download
+
+from src.utils.image_utils import prepare_image
+from src.models.briarmbg import BriaRMBG
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input', type=str, default='preprocessed_data/scissors/scissors.png')
+ parser.add_argument('--output', type=str, default='preprocessed_data')
+ args = parser.parse_args()
+
+ input_path = args.input
+ output_path = args.output
+
+ assert os.path.exists(input_path), f'{input_path} does not exist'
+
+ mesh_name = os.path.basename(os.path.dirname(input_path))
+ output_path = os.path.join(output_path, mesh_name)
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+
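+    # Fetch the RMBG-1.4 weights (cached locally by huggingface_hub) and remove the background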
+ rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
+ snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
+ rendering_rmbg = prepare_image(input_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net, device=device)
+    rendering_rmbg.save(os.path.join(output_path, 'rendering_rmbg.png'))
\ No newline at end of file
diff --git a/scripts/inference_partcrafter.py b/scripts/inference_partcrafter.py
new file mode 100644
index 0000000000000000000000000000000000000000..864130f6eb62b7471ec12cdf06c89a89dde92445
--- /dev/null
+++ b/scripts/inference_partcrafter.py
@@ -0,0 +1,176 @@
+import argparse
+import os
+import sys
+from glob import glob
+import time
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import torch
+import trimesh
+from huggingface_hub import snapshot_download
+from PIL import Image
+from accelerate.utils import set_seed
+
+from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
+from src.utils.render_utils import render_views_around_mesh, render_normal_views_around_mesh, make_grid_for_images_or_videos, export_renderings
+from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
+from src.utils.image_utils import prepare_image
+from src.models.briarmbg import BriaRMBG
+
+@torch.no_grad()
+def run_triposg(
+ pipe: Any,
+ image_input: Union[str, Image.Image],
+ num_parts: int,
+ rmbg_net: Any,
+ seed: int,
+ num_tokens: int = 1024,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.0,
+    max_num_expanded_coords: int = int(1e9),
+ use_flash_decoder: bool = False,
+ rmbg: bool = False,
+ dtype: torch.dtype = torch.float16,
+ device: str = "cuda",
+) -> Tuple[List[trimesh.Trimesh], Image.Image]:
+
+ if rmbg:
+ img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
+ else:
+ img_pil = Image.open(image_input)
+ start_time = time.time()
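+    # The same conditioning image is repeated for every part, and the number of parts
+    # to generate jointly is passed to the pipeline via attention_kwargs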
+ outputs = pipe(
+ image=[img_pil] * num_parts,
+ attention_kwargs={"num_parts": num_parts},
+ num_tokens=num_tokens,
+ generator=torch.Generator(device=pipe.device).manual_seed(seed),
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ max_num_expanded_coords=max_num_expanded_coords,
+ use_flash_decoder=use_flash_decoder,
+ ).meshes
+ end_time = time.time()
+ print(f"Time elapsed: {end_time - start_time:.2f} seconds")
+ for i in range(len(outputs)):
+ if outputs[i] is None:
+            # If the generated mesh is None (decoding error), use a dummy mesh
+ outputs[i] = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
+ return outputs, img_pil
+
+MAX_NUM_PARTS = 16
+
+if __name__ == "__main__":
+ device = "cuda"
+ dtype = torch.float16
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--image_path", type=str, required=True)
+ parser.add_argument("--num_parts", type=int, required=True, help="number of parts to generate")
+ parser.add_argument("--output_dir", type=str, default="./results")
+ parser.add_argument("--tag", type=str, default=None)
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument("--num_tokens", type=int, default=1024)
+ parser.add_argument("--num_inference_steps", type=int, default=50)
+ parser.add_argument("--guidance_scale", type=float, default=7.0)
+    parser.add_argument("--max_num_expanded_coords", type=int, default=int(1e9))
+ parser.add_argument("--use_flash_decoder", action="store_true")
+ parser.add_argument("--rmbg", action="store_true")
+ parser.add_argument("--render", action="store_true")
+ args = parser.parse_args()
+
+ assert 1 <= args.num_parts <= MAX_NUM_PARTS, f"num_parts must be in [1, {MAX_NUM_PARTS}]"
+
+ # download pretrained weights
+ partcrafter_weights_dir = "pretrained_weights/PartCrafter"
+ rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
+ snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
+ snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)
+
+ # init rmbg model for background removal
+ rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(device)
+ rmbg_net.eval()
+
+    # init PartCrafter pipeline
+ pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weights_dir).to(device, dtype)
+
+ set_seed(args.seed)
+
+ # run inference
+ outputs, processed_image = run_triposg(
+ pipe,
+ image_input=args.image_path,
+ num_parts=args.num_parts,
+ rmbg_net=rmbg_net,
+ seed=args.seed,
+ num_tokens=args.num_tokens,
+ num_inference_steps=args.num_inference_steps,
+ guidance_scale=args.guidance_scale,
+ max_num_expanded_coords=args.max_num_expanded_coords,
+ use_flash_decoder=args.use_flash_decoder,
+ rmbg=args.rmbg,
+ dtype=dtype,
+ device=device,
+ )
+
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+ if args.tag is None:
+ args.tag = time.strftime("%Y%m%d_%H_%M_%S")
+
+ export_dir = os.path.join(args.output_dir, args.tag)
+ os.makedirs(export_dir, exist_ok=True)
+
+ for i, mesh in enumerate(outputs):
+ mesh.export(os.path.join(export_dir, f"part_{i:02}.glb"))
+
+ merged_mesh = get_colored_mesh_composition(outputs)
+ merged_mesh.export(os.path.join(export_dir, "object.glb"))
+ print(f"Generated {len(outputs)} parts and saved to {export_dir}")
+
+ if args.render:
+ print("Start rendering...")
+ num_views = 36
+ radius = 4
+ fps = 18
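+        # Render RGB and normal views around the composed mesh and export them as GIFs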
+ rendered_images = render_views_around_mesh(
+ merged_mesh,
+ num_views=num_views,
+ radius=radius,
+ )
+ rendered_normals = render_normal_views_around_mesh(
+ merged_mesh,
+ num_views=num_views,
+ radius=radius,
+ )
+ rendered_grids = make_grid_for_images_or_videos(
+ [
+ [processed_image] * num_views,
+ rendered_images,
+ rendered_normals,
+ ],
+ nrow=3
+ )
+ export_renderings(
+ rendered_images,
+ os.path.join(export_dir, "rendering.gif"),
+ fps=fps,
+ )
+ export_renderings(
+ rendered_normals,
+ os.path.join(export_dir, "rendering_normal.gif"),
+ fps=fps,
+ )
+ export_renderings(
+ rendered_grids,
+ os.path.join(export_dir, "rendering_grid.gif"),
+ fps=fps,
+ )
+
+ rendered_image, rendered_normal, rendered_grid = rendered_images[0], rendered_normals[0], rendered_grids[0]
+ rendered_image.save(os.path.join(export_dir, "rendering.png"))
+ rendered_normal.save(os.path.join(export_dir, "rendering_normal.png"))
+ rendered_grid.save(os.path.join(export_dir, "rendering_grid.png"))
+ print("Rendering done.")
+
diff --git a/scripts/train_partcrafter.sh b/scripts/train_partcrafter.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c5e5f6ef91c1fcb05d516a8a92ebbda868be073c
--- /dev/null
+++ b/scripts/train_partcrafter.sh
@@ -0,0 +1,14 @@
+NUM_MACHINES=1
+NUM_LOCAL_GPUS=8
+MACHINE_RANK=0
+
+export WANDB_API_KEY="" # Modify this if you use wandb
+
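+# Launch one training process per GPU across all machines; extra CLI flags are forwarded to the script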
+accelerate launch \
+ --num_machines $NUM_MACHINES \
+ --num_processes $(( $NUM_MACHINES * $NUM_LOCAL_GPUS )) \
+ --machine_rank $MACHINE_RANK \
+ src/train_partcrafter.py \
+ --pin_memory \
+ --allow_tf32 \
+"$@"
diff --git a/settings/requirements.txt b/settings/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..81ee26977f00721b7602a5bd21a84d7e6deecf1e
--- /dev/null
+++ b/settings/requirements.txt
@@ -0,0 +1,22 @@
+scikit-learn
+gpustat
+nvitop
+diffusers
+transformers
+einops
+huggingface_hub
+opencv-python
+trimesh
+omegaconf
+scikit-image
+numpy==1.26.4
+peft
+jaxtyping
+typeguard
+diso
+matplotlib
+imageio-ffmpeg
+pyrender
+deepspeed
+wandb[media]
+colormaps
\ No newline at end of file
diff --git a/settings/setup.sh b/settings/setup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..91cf64023d268009a20cec5fe2929358c9a1089e
--- /dev/null
+++ b/settings/setup.sh
@@ -0,0 +1,3 @@
+pip install torch-cluster -f https://data.pyg.org/whl/torch-2.5.1+cu124.html
+pip install -r settings/requirements.txt
+sudo apt-get install libegl1 libegl1-mesa libgl1-mesa-dev -y # for rendering
\ No newline at end of file
diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..17f189e4d039fd6214f670f0c19eb2a353ca0e03
--- /dev/null
+++ b/src/datasets/__init__.py
@@ -0,0 +1,51 @@
+from src.utils.typing_utils import *
+
+import torch
+
+from .objaverse_part import ObjaversePartDataset, BatchedObjaversePartDataset
+
+# Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/loader.py
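+# Keeps DataLoader workers alive across epochs by reusing a single underlying iterator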
+class MultiEpochsDataLoader(torch.utils.data.DataLoader):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._DataLoader__initialized = False
+ if self.batch_sampler is None:
+ self.sampler = _RepeatSampler(self.sampler)
+ else:
+ self.batch_sampler = _RepeatSampler(self.batch_sampler)
+ self._DataLoader__initialized = True
+ self.iterator = super().__iter__()
+
+ def __len__(self):
+ return len(self.sampler) if self.batch_sampler is None else len(self.batch_sampler.sampler)
+
+ def __iter__(self):
+ for i in range(len(self)):
+ yield next(self.iterator)
+
+
+class _RepeatSampler(object):
+ """ Sampler that repeats forever.
+
+ Args:
+ sampler (Sampler)
+ """
+
+ def __init__(self, sampler):
+ self.sampler = sampler
+ if isinstance(self.sampler, torch.utils.data.sampler.BatchSampler):
+ self.batch_size = self.sampler.batch_size
+ self.drop_last = self.sampler.drop_last
+
+ def __len__(self):
+ return len(self.sampler)
+
+ def __iter__(self):
+ while True:
+ yield from iter(self.sampler)
+
+def yield_forever(iterator: Iterator[Any]):
+ while True:
+ for x in iterator:
+ yield x
\ No newline at end of file
diff --git a/src/datasets/objaverse_part.py b/src/datasets/objaverse_part.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cb15ee464820b84238356b5032f7559d6c3049e
--- /dev/null
+++ b/src/datasets/objaverse_part.py
@@ -0,0 +1,197 @@
+from src.utils.typing_utils import *
+
+import json
+import os
+import random
+
+import accelerate
+import torch
+from torchvision import transforms
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+from src.utils.data_utils import load_surface, load_surfaces
+
+class ObjaversePartDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ configs: DictConfig,
+ training: bool = True,
+ ):
+ super().__init__()
+ self.configs = configs
+ self.training = training
+
+ self.min_num_parts = configs['dataset']['min_num_parts']
+ self.max_num_parts = configs['dataset']['max_num_parts']
+ self.val_min_num_parts = configs['val']['min_num_parts']
+ self.val_max_num_parts = configs['val']['max_num_parts']
+
+ self.max_iou_mean = configs['dataset'].get('max_iou_mean', None)
+ self.max_iou_max = configs['dataset'].get('max_iou_max', None)
+
+ self.shuffle_parts = configs['dataset']['shuffle_parts']
+ self.training_ratio = configs['dataset']['training_ratio']
+ self.balance_object_and_parts = configs['dataset'].get('balance_object_and_parts', False)
+
+ self.rotating_ratio = configs['dataset'].get('rotating_ratio', 0.0)
+ self.rotating_degree = configs['dataset'].get('rotating_degree', 10.0)
+ self.transform = transforms.Compose([
+ transforms.RandomRotation(degrees=(-self.rotating_degree, self.rotating_degree), fill=(255, 255, 255)),
+ ])
+
+ if isinstance(configs['dataset']['config'], ListConfig):
+ data_configs = []
+ for config in configs['dataset']['config']:
+ local_data_configs = json.load(open(config))
+ if self.balance_object_and_parts:
+ if self.training:
+ local_data_configs = local_data_configs[:int(len(local_data_configs) * self.training_ratio)]
+ else:
+ local_data_configs = local_data_configs[int(len(local_data_configs) * self.training_ratio):]
+ local_data_configs = [config for config in local_data_configs if self.val_min_num_parts <= config['num_parts'] <= self.val_max_num_parts]
+ data_configs += local_data_configs
+ else:
+ data_configs = json.load(open(configs['dataset']['config']))
+ data_configs = [config for config in data_configs if config['valid']]
+ data_configs = [config for config in data_configs if self.min_num_parts <= config['num_parts'] <= self.max_num_parts]
+ if self.max_iou_mean is not None and self.max_iou_max is not None:
+ data_configs = [config for config in data_configs if config['iou_mean'] <= self.max_iou_mean]
+ data_configs = [config for config in data_configs if config['iou_max'] <= self.max_iou_max]
+ if not self.balance_object_and_parts:
+ if self.training:
+ data_configs = data_configs[:int(len(data_configs) * self.training_ratio)]
+ else:
+ data_configs = data_configs[int(len(data_configs) * self.training_ratio):]
+ data_configs = [config for config in data_configs if self.val_min_num_parts <= config['num_parts'] <= self.val_max_num_parts]
+ self.data_configs = data_configs
+ self.image_size = (512, 512)
+
+ def __len__(self) -> int:
+ return len(self.data_configs)
+
+ def _get_data_by_config(self, data_config):
+ if 'surface_path' in data_config:
+ surface_path = data_config['surface_path']
+ surface_data = np.load(surface_path, allow_pickle=True).item()
+ # If parts is empty, the object is the only part
+ part_surfaces = surface_data['parts'] if len(surface_data['parts']) > 0 else [surface_data['object']]
+ if self.shuffle_parts:
+ random.shuffle(part_surfaces)
+ part_surfaces = load_surfaces(part_surfaces) # [N, P, 6]
+ else:
+ part_surfaces = []
+ for surface_path in data_config['surface_paths']:
+ surface_data = np.load(surface_path, allow_pickle=True).item()
+ part_surfaces.append(load_surface(surface_data))
+ part_surfaces = torch.stack(part_surfaces, dim=0) # [N, P, 6]
+ image_path = data_config['image_path']
+ image = Image.open(image_path).resize(self.image_size)
+ if random.random() < self.rotating_ratio:
+ image = self.transform(image)
+ image = np.array(image)
+ image = torch.from_numpy(image).to(torch.uint8) # [H, W, 3]
+ images = torch.stack([image] * part_surfaces.shape[0], dim=0) # [N, H, W, 3]
+ return {
+ "images": images,
+ "part_surfaces": part_surfaces,
+ }
+
+ def __getitem__(self, idx: int):
+        # This dataset only supports training with batch size == 1,
+        # because the number of parts per object is not fixed.
+        # See BatchedObjaversePartDataset for batched training.
+ data_config = self.data_configs[idx]
+ data = self._get_data_by_config(data_config)
+ return data
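+    # Illustrative usage sketch (not part of the original file; the key names below are simply the
+    # fields accessed in _get_data_by_config, e.g. 'surface_path'/'surface_paths' and 'image_path'):
+    #
+    #   dataset = ObjaversePartDataset(configs, training=True)
+    #   sample = dataset[0]
+    #   sample["images"].shape          # [N, 512, 512, 3], the same conditioning image repeated per part
+    #   sample["part_surfaces"].shape   # [N, P, 6], xyz plus 3 feature channels per surface point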
+
+class BatchedObjaversePartDataset(ObjaversePartDataset):
+ def __init__(
+ self,
+ configs: DictConfig,
+ batch_size: int,
+ is_main_process: bool = False,
+ shuffle: bool = True,
+ training: bool = True,
+ ):
+ assert training
+ assert batch_size > 1
+ super().__init__(configs, training)
+ self.batch_size = batch_size
+ self.is_main_process = is_main_process
+ if batch_size < self.max_num_parts:
+ self.data_configs = [config for config in self.data_configs if config['num_parts'] <= batch_size]
+
+ if shuffle:
+ random.shuffle(self.data_configs)
+
+ self.object_configs = [config for config in self.data_configs if config['num_parts'] == 1]
+ self.parts_configs = [config for config in self.data_configs if config['num_parts'] > 1]
+
+ self.object_ratio = configs['dataset']['object_ratio']
+        # Keep the number of single-object samples at object_ratio times the number of multi-part samples
+ self.object_configs = self.object_configs[:int(len(self.parts_configs) * self.object_ratio)]
+
+ dropped_data_configs = self.parts_configs + self.object_configs
+ if shuffle:
+ random.shuffle(dropped_data_configs)
+
+ self.data_configs = self._get_batched_configs(dropped_data_configs, batch_size)
+
+ def _get_batched_configs(self, data_configs, batch_size):
+ batched_data_configs = []
+ num_data_configs = len(data_configs)
+ progress_bar = tqdm(
+ range(len(data_configs)),
+ desc="Batching Dataset",
+ ncols=125,
+ disable=not self.is_main_process,
+ )
+ while len(data_configs) > 0:
+ temp_batch = []
+ temp_num_parts = 0
+ unchosen_configs = []
+ while temp_num_parts < batch_size and len(data_configs) > 0:
+ config = data_configs.pop() # pop the last config
+ num_parts = config['num_parts']
+ if temp_num_parts + num_parts <= batch_size:
+ temp_batch.append(config)
+ temp_num_parts += num_parts
+ progress_bar.update(1)
+ else:
+ unchosen_configs.append(config) # add back to the end
+ data_configs = data_configs + unchosen_configs # concat the unchosen configs
+ if temp_num_parts == batch_size:
+ # Successfully get a batch
+ if len(temp_batch) < batch_size:
+ # pad the batch
+ temp_batch += [{}] * (batch_size - len(temp_batch))
+ batched_data_configs += temp_batch
+            # Otherwise we reach this point because len(data_configs) == 0, i.e. the remaining
+            # configs cannot fill a complete batch, so the incomplete batch is dropped.
+ progress_bar.close()
+ return batched_data_configs
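+    # A small worked example of the greedy batching above (illustrative numbers only):
+    # with batch_size=8 and remaining part counts [3, 5, 2, 6], popping from the end gives
+    #   6 -> fits (total 6), 2 -> fits (total 8)   => batch [6, 2], padded with {} to 8 entries
+    #   5 -> fits (total 5), 3 -> fits (total 8)   => batch [5, 3], padded with {} to 8 entries
+    # and any leftover configs that cannot complete a final batch are dropped.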
+
+ def __getitem__(self, idx: int):
+ data_config = self.data_configs[idx]
+ if len(data_config) == 0:
+ # placeholder
+ return {}
+ data = self._get_data_by_config(data_config)
+ return data
+
+ def collate_fn(self, batch):
+ batch = [data for data in batch if len(data) > 0]
+ images = torch.cat([data['images'] for data in batch], dim=0) # [N, H, W, 3]
+ surfaces = torch.cat([data['part_surfaces'] for data in batch], dim=0) # [N, P, 6]
+ num_parts = torch.LongTensor([data['part_surfaces'].shape[0] for data in batch])
+ assert images.shape[0] == surfaces.shape[0] == num_parts.sum() == self.batch_size
+ batch = {
+ "images": images,
+ "part_surfaces": surfaces,
+ "num_parts": num_parts,
+ }
+ return batch
\ No newline at end of file
diff --git a/src/models/attention_processor.py b/src/models/attention_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a237bfd9579727e402ef38a664122c19cd68c1b
--- /dev/null
+++ b/src/models/attention_processor.py
@@ -0,0 +1,624 @@
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention_processor import Attention
+from diffusers.utils import logging
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
+from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
+from einops import rearrange
+from torch import nn
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+class FlashTripo2AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the Tripo2DiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+ """
+
+ def __init__(self, topk=True):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+                "FlashTripo2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or later."
+ )
+ self.topk = topk
+
+ def qkv(self, attn, q, k, v, attn_mask, dropout_p, is_causal):
+ if k.shape[-2] == 3072:
+ topk = 1024
+ elif k.shape[-2] == 512:
+ topk = 256
+ else:
+ topk = k.shape[-2] // 3
+
+ if self.topk is True:
+ q1 = q[:, :, ::100, :]
+ sim = q1 @ k.transpose(-1, -2)
+ sim = torch.mean(sim, -2)
+ topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
+ topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
+ v0 = torch.gather(v, dim=-2, index=topk_ind)
+ k0 = torch.gather(k, dim=-2, index=topk_ind)
+ out = F.scaled_dot_product_attention(q, k0, v0)
+ elif self.topk is False:
+ out = F.scaled_dot_product_attention(q, k, v)
+ else:
+ idx, counts = self.topk
+ start = 0
+ outs = []
+ for grid_coord, count in zip(idx, counts):
+ end = start + count
+ q_chunk = q[:, :, start:end, :]
+ q1 = q_chunk[:, :, ::50, :]
+ sim = q1 @ k.transpose(-1, -2)
+ sim = torch.mean(sim, -2)
+ topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
+ topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
+ v0 = torch.gather(v, dim=-2, index=topk_ind)
+ k0 = torch.gather(k, dim=-2, index=topk_ind)
+ out = F.scaled_dot_product_attention(q_chunk, k0, v0)
+ outs.append(out)
+ start += count
+ out = torch.cat(outs, dim=-2)
+ self.topk = False
+ return out
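+    # Sketch of the idea behind the top-k branch above (illustrative numbers): for a 3072-token
+    # KV sequence, every 100th query token probes all keys, the similarities are averaged over the
+    # probes, and only the 1024 best-scoring key/value tokens are kept before calling
+    # scaled_dot_product_attention, trading a little accuracy for much cheaper attention.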
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from diffusers.models.embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(
+ batch_size, channel, height * width
+ ).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape
+ if encoder_hidden_states is None
+ else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(
+ attention_mask, sequence_length, batch_size
+ )
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
+ )
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+ 1, 2
+ )
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
+ encoder_hidden_states
+ )
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ # NOTE that tripo2 split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
+ if not attn.is_cross_attention:
+ qkv = torch.cat((query, key, value), dim=-1)
+ split_size = qkv.shape[-1] // attn.heads // 3
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+ else:
+ kv = torch.cat((key, value), dim=-1)
+ split_size = kv.shape[-1] // attn.heads // 2
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
+ key, value = torch.split(kv, split_size, dim=-1)
+
+ head_dim = key.shape[-1]
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ # flashvdm topk
+ hidden_states = self.qkv(attn, query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+class TripoSGAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the TripoSG model. It applies a normalization layer and rotary embedding on the query and key vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+                "TripoSGAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or later."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from diffusers.models.embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(
+ batch_size, channel, height * width
+ ).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape
+ if encoder_hidden_states is None
+ else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(
+ attention_mask, sequence_length, batch_size
+ )
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
+ )
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+ 1, 2
+ )
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
+ encoder_hidden_states
+ )
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
+ if not attn.is_cross_attention:
+ qkv = torch.cat((query, key, value), dim=-1)
+ split_size = qkv.shape[-1] // attn.heads // 3
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+ else:
+ kv = torch.cat((key, value), dim=-1)
+ split_size = kv.shape[-1] // attn.heads // 2
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
+ key, value = torch.split(kv, split_size, dim=-1)
+
+ head_dim = key.shape[-1]
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class FusedTripoSGAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
+    projection layers. This is used in the TripoSG model. It applies a normalization layer and rotary embedding on
+    the query and key vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+                "FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or later."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from diffusers.models.embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(
+ batch_size, channel, height * width
+ ).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape
+ if encoder_hidden_states is None
+ else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(
+ attention_mask, sequence_length, batch_size
+ )
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
+ )
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+ 1, 2
+ )
+
+ # NOTE that pre-trained split heads first, then split qkv
+ if encoder_hidden_states is None:
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // attn.heads // 3
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+ else:
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
+ encoder_hidden_states
+ )
+ query = attn.to_q(hidden_states)
+
+ kv = attn.to_kv(encoder_hidden_states)
+ split_size = kv.shape[-1] // attn.heads // 2
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
+ key, value = torch.split(kv, split_size, dim=-1)
+
+ head_dim = key.shape[-1]
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+# Modified from https://github.com/VAST-AI-Research/MIDI-3D/blob/main/midi/models/attention_processor.py#L264
+class PartCrafterAttnProcessor:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the PartCrafter model. It applies a normalization layer and rotary embedding on the query and key vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+                "PartCrafterAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or later."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ num_parts: Optional[Union[int, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ from diffusers.models.embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(
+ batch_size, channel, height * width
+ ).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape
+ if encoder_hidden_states is None
+ else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(
+ attention_mask, sequence_length, batch_size
+ )
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
+ )
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+ 1, 2
+ )
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
+ encoder_hidden_states
+ )
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
+ if not attn.is_cross_attention:
+ qkv = torch.cat((query, key, value), dim=-1)
+ split_size = qkv.shape[-1] // attn.heads // 3
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+ else:
+ kv = torch.cat((key, value), dim=-1)
+ split_size = kv.shape[-1] // attn.heads // 2
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
+ key, value = torch.split(kv, split_size, dim=-1)
+
+ head_dim = key.shape[-1]
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ if isinstance(num_parts, torch.Tensor):
+            # num_parts given as a tensor: assumed to hold per-sample part counts during training;
+            # classifier-free guidance is not handled in this branch.
+ idx = 0
+ hidden_states_list = []
+ for n_p in num_parts:
+ k = key[idx : idx + n_p]
+ v = value[idx : idx + n_p]
+ q = query[idx : idx + n_p]
+ idx += n_p
+ if k.shape[2] == q.shape[2]:
+ # Assuming self-attention
+ # Here 'b' is always 1
+ k = rearrange(
+ k, "(b ni) h nt c -> b h (ni nt) c", ni=n_p
+ ) # [b, h, ni*nt, c]
+ v = rearrange(
+ v, "(b ni) h nt c -> b h (ni nt) c", ni=n_p
+ ) # [b, h, ni*nt, c]
+ else:
+ # Assuming cross-attention
+ # Here 'b' is always 1
+ k = k[::n_p] # [b, h, nt, c]
+ v = v[::n_p] # [b, h, nt, c]
+ # Here 'b' is always 1
+ q = rearrange(
+ q, "(b ni) h nt c -> b h (ni nt) c", ni=n_p
+ ) # [b, h, ni*nt, c]
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ h_s = F.scaled_dot_product_attention(
+ q, k, v,
+ dropout_p=0.0,
+ is_causal=False,
+ )
+ h_s = h_s.transpose(1, 2).reshape(
+ n_p, -1, attn.heads * head_dim
+ )
+ h_s = h_s.to(query.dtype)
+ hidden_states_list.append(h_s)
+ hidden_states = torch.cat(hidden_states_list, dim=0)
+
+ elif isinstance(num_parts, int):
+            # num_parts given as an int: a single part count shared by the whole batch,
+            # e.g. when classifier-free guidance duplicates the batch.
+ if key.shape[2] == query.shape[2]:
+ # Assuming self-attention
+ # Here we need 'b' when using classifier-free guidance
+ key = rearrange(
+ key, "(b ni) h nt c -> b h (ni nt) c", ni=num_parts
+ ) # [b, h, ni*nt, c]
+ value = rearrange(
+ value, "(b ni) h nt c -> b h (ni nt) c", ni=num_parts
+ ) # [b, h, ni*nt, c]
+ else:
+ # Assuming cross-attention
+ # Here we need 'b' when using classifier-free guidance
+ # Control signal is repeated ni times within each (b, ni)
+ # We select only the first instance per group
+ key = key[::num_parts] # [b, h, nt, c]
+ value = value[::num_parts] # [b, h, nt, c]
+ query = rearrange(
+ query, "(b ni) h nt c -> b h (ni nt) c", ni=num_parts
+ ) # [b, h, ni*nt, c]
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = F.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ dropout_p=0.0,
+ is_causal=False,
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ hidden_states = hidden_states.to(query.dtype)
+
+ else:
+ raise ValueError(
+ "num_parts must be a torch.Tensor or int, but got {}".format(type(num_parts))
+ )
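+        # Shape sketch for the regrouping above (assumed example): with num_parts=4, nt tokens per
+        # part and h heads, self-attention reshapes q/k/v from [4, h, nt, c] to [1, h, 4*nt, c], so
+        # tokens of all parts of one object attend to each other, while cross-attention keeps a
+        # single copy of the shared image tokens (key[::4]) against the concatenated part queries.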
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
diff --git a/src/models/autoencoders/__init__.py b/src/models/autoencoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..64a9ea290fb3c006737000b0e046ac492e0ec26f
--- /dev/null
+++ b/src/models/autoencoders/__init__.py
@@ -0,0 +1 @@
+from .autoencoder_kl_triposg import TripoSGVAEModel
diff --git a/src/models/autoencoders/autoencoder_kl_triposg.py b/src/models/autoencoders/autoencoder_kl_triposg.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aafff9597098d3f917c329fb389aef8a7e95260
--- /dev/null
+++ b/src/models/autoencoders/autoencoder_kl_triposg.py
@@ -0,0 +1,536 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention_processor import Attention, AttentionProcessor
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import FP32LayerNorm, LayerNorm
+from diffusers.utils import logging
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from einops import repeat
+from torch_cluster import fps
+from tqdm import tqdm
+
+from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0, FlashTripo2AttnProcessor2_0
+from ..embeddings import FrequencyPositionalEmbedding
+from ..transformers.partcrafter_transformer import DiTBlock
+from .vae import DiagonalGaussianDistribution
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class TripoSGEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ dim: int = 512,
+ num_attention_heads: int = 8,
+ num_layers: int = 8,
+ ):
+ super().__init__()
+
+ self.proj_in = nn.Linear(in_channels, dim, bias=True)
+
+ self.blocks = nn.ModuleList(
+ [
+ DiTBlock(
+ dim=dim,
+ num_attention_heads=num_attention_heads,
+ use_self_attention=False,
+ use_cross_attention=True,
+ cross_attention_dim=dim,
+ cross_attention_norm_type="layer_norm",
+ activation_fn="gelu",
+ norm_type="fp32_layer_norm",
+ norm_eps=1e-5,
+ qk_norm=False,
+ qkv_bias=False,
+ ) # cross attention
+ ]
+ + [
+ DiTBlock(
+ dim=dim,
+ num_attention_heads=num_attention_heads,
+ use_self_attention=True,
+ self_attention_norm_type="fp32_layer_norm",
+ use_cross_attention=False,
+ activation_fn="gelu",
+ norm_type="fp32_layer_norm",
+ norm_eps=1e-5,
+ qk_norm=False,
+ qkv_bias=False,
+ )
+ for _ in range(num_layers) # self attention
+ ]
+ )
+
+ self.norm_out = LayerNorm(dim)
+
+ def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor):
+ hidden_states = self.proj_in(sample_1)
+ encoder_hidden_states = self.proj_in(sample_2)
+
+ for layer, block in enumerate(self.blocks):
+ if layer == 0:
+ hidden_states = block(
+ hidden_states, encoder_hidden_states=encoder_hidden_states
+ )
+ else:
+ hidden_states = block(hidden_states)
+
+ hidden_states = self.norm_out(hidden_states)
+
+ return hidden_states
+
+
+class TripoSGDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 1,
+ dim: int = 512,
+ num_attention_heads: int = 8,
+ num_layers: int = 16,
+ grad_type: str = "analytical",
+ grad_interval: float = 0.001,
+ ):
+ super().__init__()
+
+ if grad_type not in ["numerical", "analytical"]:
+            raise ValueError(f"grad_type must be one of ['numerical', 'analytical'], but got {grad_type}.")
+ self.grad_type = grad_type
+ self.grad_interval = grad_interval
+
+ self.blocks = nn.ModuleList(
+ [
+ DiTBlock(
+ dim=dim,
+ num_attention_heads=num_attention_heads,
+ use_self_attention=True,
+ self_attention_norm_type="fp32_layer_norm",
+ use_cross_attention=False,
+ activation_fn="gelu",
+ norm_type="fp32_layer_norm",
+ norm_eps=1e-5,
+ qk_norm=False,
+ qkv_bias=False,
+ )
+ for _ in range(num_layers) # self attention
+ ]
+ + [
+ DiTBlock(
+ dim=dim,
+ num_attention_heads=num_attention_heads,
+ use_self_attention=False,
+ use_cross_attention=True,
+ cross_attention_dim=dim,
+ cross_attention_norm_type="layer_norm",
+ activation_fn="gelu",
+ norm_type="fp32_layer_norm",
+ norm_eps=1e-5,
+ qk_norm=False,
+ qkv_bias=False,
+ ) # cross attention
+ ]
+ )
+
+ self.proj_query = nn.Linear(in_channels, dim, bias=True)
+
+ self.norm_out = LayerNorm(dim)
+ self.proj_out = nn.Linear(dim, out_channels, bias=True)
+
+ def set_topk(self, topk):
+ self.blocks[-1].set_topk(topk)
+
+ def set_flash_processor(self, processor):
+ self.blocks[-1].set_flash_processor(processor)
+
+ def query_geometry(
+ self,
+ model_fn: callable,
+ queries: torch.Tensor,
+ sample: torch.Tensor,
+ grad: bool = False,
+ ):
+ logits = model_fn(queries, sample)
+ if grad:
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
+ if self.grad_type == "numerical":
+ interval = self.grad_interval
+ grad_value = []
+ for offset in [
+ (interval, 0, 0),
+ (0, interval, 0),
+ (0, 0, interval),
+ ]:
+ offset_tensor = torch.tensor(offset, device=queries.device)[
+ None, :
+ ]
+ res_p = model_fn(queries + offset_tensor, sample)[..., 0]
+ res_n = model_fn(queries - offset_tensor, sample)[..., 0]
+ grad_value.append((res_p - res_n) / (2 * interval))
+ grad_value = torch.stack(grad_value, dim=-1)
+ else:
+ queries_d = torch.clone(queries)
+ queries_d.requires_grad = True
+ with torch.enable_grad():
+ res_d = model_fn(queries_d, sample)
+ grad_value = torch.autograd.grad(
+ res_d,
+ [queries_d],
+ grad_outputs=torch.ones_like(res_d),
+ create_graph=self.training,
+ )[0]
+ else:
+ grad_value = None
+
+ return logits, grad_value
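+    # The "numerical" branch above estimates the surface gradient with central differences,
+    #   df/dx_i ~= (f(x + h * e_i) - f(x - h * e_i)) / (2 * h),  with h = grad_interval,
+    # while the "analytical" branch obtains it directly via torch.autograd.grad on the logits.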
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ queries: torch.Tensor,
+ kv_cache: Optional[torch.Tensor] = None,
+ ):
+ if kv_cache is None:
+ hidden_states = sample
+ for _, block in enumerate(self.blocks[:-1]):
+ hidden_states = block(hidden_states)
+ kv_cache = hidden_states
+
+ # query grid logits by cross attention
+ def query_fn(q, kv):
+ q = self.proj_query(q)
+ l = self.blocks[-1](q, encoder_hidden_states=kv)
+ return self.proj_out(self.norm_out(l))
+
+ logits, grad = self.query_geometry(
+ query_fn, queries, kv_cache, grad=self.training
+ )
+        logits = logits * -1 if not isinstance(logits, tuple) else logits[0] * -1
+
+ return logits, kv_cache
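+    # The kv_cache returned above lets the caller run the self-attention stack once per latent set
+    # and reuse the cached features as cross-attention context for every chunk of query points,
+    # which is what keeps the chunked loop in TripoSGVAEModel._decode cheap.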
+
+
+class TripoSGVAEModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3, # NOTE xyz instead of feature dim
+ latent_channels: int = 64,
+ num_attention_heads: int = 8,
+ width_encoder: int = 512,
+ width_decoder: int = 1024,
+ num_layers_encoder: int = 8,
+ num_layers_decoder: int = 16,
+ embedding_type: str = "frequency",
+ embed_frequency: int = 8,
+ embed_include_pi: bool = False,
+ ):
+ super().__init__()
+
+ self.out_channels = 1
+
+ if embedding_type == "frequency":
+ self.embedder = FrequencyPositionalEmbedding(
+ num_freqs=embed_frequency,
+ logspace=True,
+ input_dim=in_channels,
+ include_pi=embed_include_pi,
+ )
+ else:
+ raise NotImplementedError(
+ f"Embedding type {embedding_type} is not supported."
+ )
+
+ self.encoder = TripoSGEncoder(
+ in_channels=in_channels + self.embedder.out_dim,
+ dim=width_encoder,
+ num_attention_heads=num_attention_heads,
+ num_layers=num_layers_encoder,
+ )
+ self.decoder = TripoSGDecoder(
+ in_channels=self.embedder.out_dim,
+ out_channels=self.out_channels,
+ dim=width_decoder,
+ num_attention_heads=num_attention_heads,
+ num_layers=num_layers_decoder,
+ )
+
+ self.quant = nn.Linear(width_encoder, latent_channels * 2, bias=True)
+ self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)
+
+ self.use_slicing = False
+ self.slicing_length = 1
+
+ def set_flash_decoder(self):
+ self.decoder.set_flash_processor(FlashTripo2AttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+        """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError(
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
+ )
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+        """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
+
+ def enable_slicing(self, slicing_length: int = 1) -> None:
+ r"""
+        Enable sliced VAE encoding and decoding. When this option is enabled, the VAE splits the input tensor into
+        slices and processes them in several steps. This is useful to save memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+ self.slicing_length = slicing_length
+
+ def disable_slicing(self) -> None:
+ r"""
+        Disable sliced VAE encoding and decoding. If `enable_slicing` was previously enabled, this method goes back
+        to processing the full batch in one step.
+ """
+ self.use_slicing = False
+
+ def _sample_features(
+ self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
+ ):
+ """
+        Sample a subset of points (and their features) from the input point cloud.
+
+ Args:
+ x (torch.Tensor): The input point cloud. shape: (B, N, C)
+ num_tokens (int, optional): The number of points to sample. Defaults to 2048.
+ seed (Optional[int], optional): The random seed. Defaults to None.
+ """
+ rng = np.random.default_rng(seed)
+ indices = rng.choice(
+ x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
+ )
+ selected_points = x[:, indices]
+
+ batch_size, num_points, num_channels = selected_points.shape
+ flattened_points = selected_points.view(batch_size * num_points, num_channels)
+ batch_indices = (
+ torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
+ )
+
+ # fps sampling
+ sampling_ratio = 1.0 / 4
+ sampled_indices = fps(
+ flattened_points[:, :3],
+ batch_indices,
+ ratio=sampling_ratio,
+ random_start=self.training,
+ )
+ sampled_points = flattened_points[sampled_indices].view(
+ batch_size, -1, num_channels
+ )
+
+ return sampled_points
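+    # Net effect of _sample_features: 4 * num_tokens points are drawn at random and farthest-point
+    # sampling with ratio 1/4 reduces them back to (roughly) num_tokens well-spread points per batch.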
+
+ def _encode(
+ self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
+ ):
+ position_channels = self.config.in_channels
+ positions, features = x[..., :position_channels], x[..., position_channels:]
+ x_kv = torch.cat([self.embedder(positions), features], dim=-1)
+
+ sampled_x = self._sample_features(x, num_tokens, seed)
+ positions, features = (
+ sampled_x[..., :position_channels],
+ sampled_x[..., position_channels:],
+ )
+ x_q = torch.cat([self.embedder(positions), features], dim=-1)
+
+ x = self.encoder(x_q, x_kv)
+
+ x = self.quant(x)
+
+ return x
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True, **kwargs
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of point features into latents.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [
+ self._encode(x_slice, **kwargs)
+ for x_slice in x.split(self.slicing_length)
+ ]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x, **kwargs)
+
+ posterior = DiagonalGaussianDistribution(h, feature_dim=-1)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(
+ self,
+ z: torch.Tensor,
+ sampled_points: torch.Tensor,
+ num_chunks: int = 50000,
+ to_cpu: bool = False,
+ return_dict: bool = True,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ xyz_samples = sampled_points
+
+ z = self.post_quant(z)
+
+ num_points = xyz_samples.shape[1]
+ kv_cache = None
+ dec = []
+
+ for i in range(0, num_points, num_chunks):
+ queries = xyz_samples[:, i : i + num_chunks, :].to(z.device, dtype=z.dtype)
+ queries = self.embedder(queries)
+
+ z_, kv_cache = self.decoder(z, queries, kv_cache)
+ dec.append(z_ if not to_cpu else z_.cpu())
+
+ z = torch.cat(dec, dim=1)
+
+ if not return_dict:
+ return (z,)
+
+ return DecoderOutput(sample=z)
+
+ @apply_forward_hook
+ def decode(
+ self,
+ z: torch.Tensor,
+ sampled_points: torch.Tensor,
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [
+ self._decode(z_slice, p_slice, **kwargs).sample
+ for z_slice, p_slice in zip(
+ z.split(self.slicing_length),
+ sampled_points.split(self.slicing_length),
+ )
+ ]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z, sampled_points, **kwargs).sample
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def forward(self, x: torch.Tensor):
+ pass
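+
+# Minimal usage sketch (not part of the original file; point counts and the 6-channel surface
+# layout, xyz plus 3 feature channels, are assumptions based on the code above):
+#
+#   vae = TripoSGVAEModel()
+#   surface = torch.randn(1, 20480, 6)                    # sampled surface points + features
+#   posterior = vae.encode(surface).latent_dist
+#   latents = posterior.sample()                          # [1, num_tokens, latent_channels]
+#   queries = torch.rand(1, 4096, 3) * 2 - 1              # query positions, assumed in [-1, 1]
+#   logits = vae.decode(latents, sampled_points=queries).sample   # [1, 4096, 1] geometry logits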
diff --git a/src/models/autoencoders/vae.py b/src/models/autoencoders/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..01aae603a3f1b298bacd723053d396c614028018
--- /dev/null
+++ b/src/models/autoencoders/vae.py
@@ -0,0 +1,69 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from diffusers.utils.torch_utils import randn_tensor
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(
+ self,
+ parameters: torch.Tensor,
+ deterministic: bool = False,
+ feature_dim: int = 1,
+ ):
+ self.parameters = parameters
+ self.feature_dim = feature_dim
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feature_dim)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+ )
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+ # make sure sample is on the same device as the parameters and has same dtype
+ sample = randn_tensor(
+ self.mean.shape,
+ generator=generator,
+ device=self.parameters.device,
+ dtype=self.parameters.dtype,
+ )
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(
+ self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
+ ) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self) -> torch.Tensor:
+ return self.mean
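+
+# Reparameterization sketch: sample() draws eps ~ N(0, I) and returns mean + std * eps, and
+# kl() with other=None is the closed-form KL divergence to a standard normal,
+#   0.5 * sum(mean^2 + var - 1 - logvar),
+# which is the usual VAE regularization term.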
diff --git a/src/models/briarmbg.py b/src/models/briarmbg.py
new file mode 100644
index 0000000000000000000000000000000000000000..fadd4514f613135a60fca1cf263fe6d49cf9d454
--- /dev/null
+++ b/src/models/briarmbg.py
@@ -0,0 +1,464 @@
+"""
+Source and Copyright Notice:
+This code is from briaai/RMBG-1.4
+Original repository: https://huggingface.co/briaai/RMBG-1.4
+Copyright belongs to briaai
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin
+
+class REBNCONV(nn.Module):
+ def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
+ super(REBNCONV,self).__init__()
+
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
+ self.relu_s1 = nn.ReLU(inplace=True)
+
+ def forward(self,x):
+
+ hx = x
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+ return xout
+
+## upsample tensor 'src' to have the same spatial size as tensor 'tar'
+def _upsample_like(src,tar):
+
+ src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
+
+ return src
+
+
+### RSU-7 ###
+class RSU7(nn.Module):
+
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
+ super(RSU7,self).__init__()
+
+ self.in_ch = in_ch
+ self.mid_ch = mid_ch
+ self.out_ch = out_ch
+
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
+
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+ def forward(self,x):
+ b, c, h, w = x.shape
+
+ hx = x
+ hxin = self.rebnconvin(hx)
+
+ hx1 = self.rebnconv1(hxin)
+ hx = self.pool1(hx1)
+
+ hx2 = self.rebnconv2(hx)
+ hx = self.pool2(hx2)
+
+ hx3 = self.rebnconv3(hx)
+ hx = self.pool3(hx3)
+
+ hx4 = self.rebnconv4(hx)
+ hx = self.pool4(hx4)
+
+ hx5 = self.rebnconv5(hx)
+ hx = self.pool5(hx5)
+
+ hx6 = self.rebnconv6(hx)
+
+ hx7 = self.rebnconv7(hx6)
+
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
+ hx6dup = _upsample_like(hx6d,hx5)
+
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
+ hx5dup = _upsample_like(hx5d,hx4)
+
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
+ hx4dup = _upsample_like(hx4d,hx3)
+
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
+ hx3dup = _upsample_like(hx3d,hx2)
+
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+ hx2dup = _upsample_like(hx2d,hx1)
+
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+ return hx1d + hxin
+
+
+### RSU-6 ###
+class RSU6(nn.Module):
+
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+ super(RSU6,self).__init__()
+
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+ def forward(self,x):
+
+ hx = x
+
+ hxin = self.rebnconvin(hx)
+
+ hx1 = self.rebnconv1(hxin)
+ hx = self.pool1(hx1)
+
+ hx2 = self.rebnconv2(hx)
+ hx = self.pool2(hx2)
+
+ hx3 = self.rebnconv3(hx)
+ hx = self.pool3(hx3)
+
+ hx4 = self.rebnconv4(hx)
+ hx = self.pool4(hx4)
+
+ hx5 = self.rebnconv5(hx)
+
+ hx6 = self.rebnconv6(hx5)
+
+
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
+ hx5dup = _upsample_like(hx5d,hx4)
+
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
+ hx4dup = _upsample_like(hx4d,hx3)
+
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
+ hx3dup = _upsample_like(hx3d,hx2)
+
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+ hx2dup = _upsample_like(hx2d,hx1)
+
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+ return hx1d + hxin
+
+### RSU-5 ###
+class RSU5(nn.Module):
+
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+ super(RSU5,self).__init__()
+
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+ def forward(self,x):
+
+ hx = x
+
+ hxin = self.rebnconvin(hx)
+
+ hx1 = self.rebnconv1(hxin)
+ hx = self.pool1(hx1)
+
+ hx2 = self.rebnconv2(hx)
+ hx = self.pool2(hx2)
+
+ hx3 = self.rebnconv3(hx)
+ hx = self.pool3(hx3)
+
+ hx4 = self.rebnconv4(hx)
+
+ hx5 = self.rebnconv5(hx4)
+
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
+ hx4dup = _upsample_like(hx4d,hx3)
+
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
+ hx3dup = _upsample_like(hx3d,hx2)
+
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+ hx2dup = _upsample_like(hx2d,hx1)
+
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+ return hx1d + hxin
+
+### RSU-4 ###
+class RSU4(nn.Module):
+
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+ super(RSU4,self).__init__()
+
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
+
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
+
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+ def forward(self,x):
+
+ hx = x
+
+ hxin = self.rebnconvin(hx)
+
+ hx1 = self.rebnconv1(hxin)
+ hx = self.pool1(hx1)
+
+ hx2 = self.rebnconv2(hx)
+ hx = self.pool2(hx2)
+
+ hx3 = self.rebnconv3(hx)
+
+ hx4 = self.rebnconv4(hx3)
+
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
+ hx3dup = _upsample_like(hx3d,hx2)
+
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
+ hx2dup = _upsample_like(hx2d,hx1)
+
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
+
+ return hx1d + hxin
+
+### RSU-4F ###
+class RSU4F(nn.Module):
+
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+ super(RSU4F,self).__init__()
+
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
+
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
+
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
+
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
+
+ def forward(self,x):
+
+ hx = x
+
+ hxin = self.rebnconvin(hx)
+
+ hx1 = self.rebnconv1(hxin)
+ hx2 = self.rebnconv2(hx1)
+ hx3 = self.rebnconv3(hx2)
+
+ hx4 = self.rebnconv4(hx3)
+
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
+
+ return hx1d + hxin
+
+
+class myrebnconv(nn.Module):
+ def __init__(self, in_ch=3,
+ out_ch=1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ dilation=1,
+ groups=1):
+ super(myrebnconv,self).__init__()
+
+ self.conv = nn.Conv2d(in_ch,
+ out_ch,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups)
+ self.bn = nn.BatchNorm2d(out_ch)
+ self.rl = nn.ReLU(inplace=True)
+
+ def forward(self,x):
+ return self.rl(self.bn(self.conv(x)))
+
+
+class BriaRMBG(nn.Module, PyTorchModelHubMixin):
+
+ def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
+ super(BriaRMBG,self).__init__()
+ in_ch=config["in_ch"]
+ out_ch=config["out_ch"]
+ self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
+ self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.stage1 = RSU7(64,32,64)
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.stage2 = RSU6(64,32,128)
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.stage3 = RSU5(128,64,256)
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.stage4 = RSU4(256,128,512)
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.stage5 = RSU4F(512,256,512)
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
+
+ self.stage6 = RSU4F(512,256,512)
+
+ # decoder
+ self.stage5d = RSU4F(1024,256,512)
+ self.stage4d = RSU4(1024,128,256)
+ self.stage3d = RSU5(512,64,128)
+ self.stage2d = RSU6(256,32,64)
+ self.stage1d = RSU7(128,16,64)
+
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
+
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
+
+ def forward(self,x):
+
+ hx = x
+
+ hxin = self.conv_in(hx)
+ #hx = self.pool_in(hxin)
+
+ #stage 1
+ hx1 = self.stage1(hxin)
+ hx = self.pool12(hx1)
+
+ #stage 2
+ hx2 = self.stage2(hx)
+ hx = self.pool23(hx2)
+
+ #stage 3
+ hx3 = self.stage3(hx)
+ hx = self.pool34(hx3)
+
+ #stage 4
+ hx4 = self.stage4(hx)
+ hx = self.pool45(hx4)
+
+ #stage 5
+ hx5 = self.stage5(hx)
+ hx = self.pool56(hx5)
+
+ #stage 6
+ hx6 = self.stage6(hx)
+ hx6up = _upsample_like(hx6,hx5)
+
+ #-------------------- decoder --------------------
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
+ hx5dup = _upsample_like(hx5d,hx4)
+
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
+ hx4dup = _upsample_like(hx4d,hx3)
+
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
+ hx3dup = _upsample_like(hx3d,hx2)
+
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
+ hx2dup = _upsample_like(hx2d,hx1)
+
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
+
+
+ #side output
+ d1 = self.side1(hx1d)
+ d1 = _upsample_like(d1,x)
+
+ d2 = self.side2(hx2d)
+ d2 = _upsample_like(d2,x)
+
+ d3 = self.side3(hx3d)
+ d3 = _upsample_like(d3,x)
+
+ d4 = self.side4(hx4d)
+ d4 = _upsample_like(d4,x)
+
+ d5 = self.side5(hx5d)
+ d5 = _upsample_like(d5,x)
+
+ d6 = self.side6(hx6)
+ d6 = _upsample_like(d6,x)
+
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
+
diff --git a/src/models/embeddings.py b/src/models/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fdb5eaa8570c87d2ffbfc7eeeaba8b070874a69
--- /dev/null
+++ b/src/models/embeddings.py
@@ -0,0 +1,96 @@
+import torch
+import torch.nn as nn
+
+
+class FrequencyPositionalEmbedding(nn.Module):
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
+ each feature dimension of `x[..., i]` into:
+ [
+ sin(x[..., i]),
+ sin(f_1*x[..., i]),
+ sin(f_2*x[..., i]),
+ ...
+ sin(f_N * x[..., i]),
+ cos(x[..., i]),
+ cos(f_1*x[..., i]),
+ cos(f_2*x[..., i]),
+ ...
+ cos(f_N * x[..., i]),
+ x[..., i] # only present if include_input is True.
+ ], here f_i is the frequency.
+
+    If logspace is True, the frequencies are f_i = 2^i for i in [0, 1, ..., num_freqs - 1];
+    otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1).
+    If include_pi is True, all frequencies are additionally multiplied by pi.
+
+ Args:
+ num_freqs (int): the number of frequencies, default is 6;
+        logspace (bool): if True, the frequencies are f_i = 2^i; otherwise they are linearly spaced
+            between 1.0 and 2^(num_freqs - 1);
+        input_dim (int): the input dimension, default is 3;
+        include_input (bool): include the input tensor or not, default is True;
+        include_pi (bool): multiply all frequencies by pi or not, default is True.
+
+ Attributes:
+        frequencies (torch.Tensor): the frequency tensor of shape [num_freqs]; 2^i per entry if logspace is True,
+            otherwise linearly spaced between 1.0 and 2^(num_freqs - 1), scaled by pi if include_pi is True;
+
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
+ otherwise, it is input_dim * num_freqs * 2.
+
+ """
+
+ def __init__(
+ self,
+ num_freqs: int = 6,
+ logspace: bool = True,
+ input_dim: int = 3,
+ include_input: bool = True,
+ include_pi: bool = True,
+ ) -> None:
+ """The initialization"""
+
+ super().__init__()
+
+ if logspace:
+ frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
+ else:
+ frequencies = torch.linspace(
+ 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
+ )
+
+ if include_pi:
+ frequencies *= torch.pi
+
+ self.register_buffer("frequencies", frequencies, persistent=False)
+ self.include_input = include_input
+ self.num_freqs = num_freqs
+
+ self.out_dim = self.get_dims(input_dim)
+
+ def get_dims(self, input_dim):
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
+
+ return out_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward process.
+
+ Args:
+ x: tensor of shape [..., dim]
+
+ Returns:
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
+ where temp is 1 if include_input is True and 0 otherwise.
+ """
+
+ if self.num_freqs > 0:
+ embed = (x[..., None].contiguous() * self.frequencies).view(
+ *x.shape[:-1], -1
+ )
+ if self.include_input:
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
+ else:
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
+ else:
+ return x
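As a quick sanity check on the embedding above, here is a usage sketch with illustrative values (it assumes `src` is importable as a package): with the defaults num_freqs=6, input_dim=3 and include_input=True, out_dim is 3 * (2 * 6 + 1) = 39.

    import torch

    from src.models.embeddings import FrequencyPositionalEmbedding

    embed = FrequencyPositionalEmbedding(num_freqs=6, input_dim=3, include_input=True)
    points = torch.rand(2, 1024, 3)   # e.g. a batch of 3D query points
    feats = embed(points)             # raw points plus sin/cos features at 6 octave frequencies
    assert feats.shape == (2, 1024, embed.out_dim) and embed.out_dim == 39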
diff --git a/src/models/transformers/__init__.py b/src/models/transformers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eadc25c031cae71652468cb7101777591da53683
--- /dev/null
+++ b/src/models/transformers/__init__.py
@@ -0,0 +1,3 @@
+from typing import Callable, Optional
+
+from .partcrafter_transformer import PartCrafterDiTModel
\ No newline at end of file
diff --git a/src/models/transformers/modeling_outputs.py b/src/models/transformers/modeling_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..0928fa0ca39275a85f8b7fa49c68af745d4c74c5
--- /dev/null
+++ b/src/models/transformers/modeling_outputs.py
@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class Transformer1DModelOutput:
+ sample: torch.FloatTensor
diff --git a/src/models/transformers/partcrafter_transformer.py b/src/models/transformers/partcrafter_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f499f2ab82377f8704f3f105f68bb2fb500d5110
--- /dev/null
+++ b/src/models/transformers/partcrafter_transformer.py
@@ -0,0 +1,813 @@
+# Copyright (c) 2025 Yuchen Lin
+
+# This code is based on TripoSG (https://github.com/VAST-AI-Research/TripoSG). Below is the statement from the original repository:
+
+# This code is based on Tencent HunyuanDiT (https://huggingface.co/Tencent-Hunyuan/HunyuanDiT),
+# which is licensed under the Tencent Hunyuan Community License Agreement.
+# Portions of this code are copied or adapted from HunyuanDiT.
+# See the original license below:
+
+# ---- Start of Tencent Hunyuan Community License Agreement ----
+
+# TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
+# Tencent Hunyuan DiT Release Date: 14 May 2024
+# THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+# By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
+# 1. DEFINITIONS.
+# a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
+# b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
+# c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
+# d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
+# e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
+# f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
+# g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
+# h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
+# i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
+# j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan DiT released at https://huggingface.co/Tencent-Hunyuan/HunyuanDiT.
+# k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
+# l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union.
+# m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
+# n. “including” shall mean including but not limited to.
+# 2. GRANT OF RIGHTS.
+# We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
+# 3. DISTRIBUTION.
+# You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
+# a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
+# b. You must cause any modified files to carry prominent notices stating that You changed the files;
+# c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
+# d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
+# You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
+# 4. ADDITIONAL COMMERCIAL TERMS.
+# If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
+# 5. RULES OF USE.
+# a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
+# b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other large language model (other than Tencent Hunyuan or Model Derivatives thereof).
+# c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
+# 6. INTELLECTUAL PROPERTY.
+# a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
+# b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
+# c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
+# d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
+# 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
+# a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
+# b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
+# c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+# 8. SURVIVAL AND TERMINATION.
+# a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
+# b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
+# 9. GOVERNING LAW AND JURISDICTION.
+# a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
+# b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
+#
+# EXHIBIT A
+# ACCEPTABLE USE POLICY
+
+# Tencent reserves the right to update this Acceptable Use Policy from time to time.
+# Last modified: [insert date]
+
+# Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
+# 1. Outside the Territory;
+# 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
+# 3. To harm Yourself or others;
+# 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
+# 5. To override or circumvent the safety guardrails and safeguards We have put in place;
+# 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
+# 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
+# 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
+# 9. To intentionally defame, disparage or otherwise harass others;
+# 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
+# 11. To generate or disseminate personal identifiable information with the purpose of harming others;
+# 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
+# 13. To impersonate another individual without consent, authorization, or legal right;
+# 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
+# 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
+# 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
+# 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
+# 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
+# 19. For military purposes;
+# 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
+
+# ---- End of Tencent Hunyuan Community License Agreement ----
+
+# Please note that the use of this code is subject to the terms and conditions
+# of the Tencent Hunyuan Community License Agreement, including the Acceptable Use Policy.
+
+from typing import *
+
+import torch
+import torch.utils.checkpoint
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.models.attention import FeedForward
+from diffusers.models.attention_processor import Attention, AttentionProcessor
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import (
+ AdaLayerNormContinuous,
+ FP32LayerNorm,
+ LayerNorm,
+)
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_version,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from torch import nn
+
+from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0, PartCrafterAttnProcessor
+from .modeling_outputs import Transformer1DModelOutput
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@maybe_allow_in_graph
+class DiTBlock(nn.Module):
+ r"""
+ Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
+ QKNorm
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        cross_attention_dim (`int`, *optional*):
+            The size of the encoder_hidden_states vector for cross attention.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"gelu"`):
+            Activation function to be used in feed-forward.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_eps (`float`, *optional*, defaults to 1e-5):
+            A small constant added to the denominator in normalization layers to prevent division by zero.
+        final_dropout (`bool`, *optional*, defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*):
+ The size of the hidden layer in the feed-forward block. Defaults to `None`.
+ ff_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the feed-forward block.
+ skip (`bool`, *optional*, defaults to `False`):
+ Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
+ qk_norm (`bool`, *optional*, defaults to `True`):
+ Whether to use normalization in QK calculation. Defaults to `True`.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ use_self_attention: bool = True,
+ self_attention_norm_type: Optional[str] = None,
+ use_cross_attention: bool = True, # ada layer norm
+ cross_attention_dim: Optional[int] = None,
+ cross_attention_norm_type: Optional[str] = "fp32_layer_norm",
+ dropout=0.0,
+ activation_fn: str = "gelu",
+ norm_type: str = "fp32_layer_norm", # TODO
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ ff_inner_dim: Optional[int] = None, # int(dim * 4) if None
+ ff_bias: bool = True,
+ skip: bool = False,
+ skip_concat_front: bool = False, # [x, skip] or [skip, x]
+ skip_norm_last: bool = False, # this is an error
+ qk_norm: bool = True,
+ qkv_bias: bool = True,
+ ):
+ super().__init__()
+
+ self.use_self_attention = use_self_attention
+ self.use_cross_attention = use_cross_attention
+ self.skip_concat_front = skip_concat_front
+ self.skip_norm_last = skip_norm_last
+ # Define 3 blocks. Each block has its own normalization layer.
+ # NOTE: when new version comes, check norm2 and norm 3
+ # 1. Self-Attn
+ if use_self_attention:
+ if (
+ self_attention_norm_type == "fp32_layer_norm"
+ or self_attention_norm_type is None
+ ):
+ self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+ else:
+ raise NotImplementedError
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=dim // num_attention_heads,
+ heads=num_attention_heads,
+ qk_norm="rms_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=qkv_bias,
+ processor=TripoSGAttnProcessor2_0(),
+ )
+
+ # 2. Cross-Attn
+ if use_cross_attention:
+ assert cross_attention_dim is not None
+
+ self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ dim_head=dim // num_attention_heads,
+ heads=num_attention_heads,
+ qk_norm="rms_norm" if qk_norm else None,
+ cross_attention_norm=cross_attention_norm_type,
+ eps=1e-6,
+ bias=qkv_bias,
+ processor=TripoSGAttnProcessor2_0(),
+ )
+
+ # 3. Feed-forward
+ self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout, ### 0.0
+ activation_fn=activation_fn, ### approx GeLU
+ final_dropout=final_dropout, ### 0.0
+ inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
+ bias=ff_bias,
+ )
+
+ # 4. Skip Connection
+ if skip:
+ self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True)
+ self.skip_linear = nn.Linear(2 * dim, dim)
+ else:
+ self.skip_linear = None
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_topk(self, topk):
+ self.flash_processor.topk = topk
+
+ def set_flash_processor(self, flash_processor):
+ self.flash_processor = flash_processor
+ self.attn2.processor = self.flash_processor
+
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ skip: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
+ # Prepare attention kwargs
+ attention_kwargs = attention_kwargs or {}
+
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Long Skip Connection
+ if self.skip_linear is not None:
+ cat = torch.cat(
+ (
+ [skip, hidden_states]
+ if self.skip_concat_front
+ else [hidden_states, skip]
+ ),
+ dim=-1,
+ )
+ if self.skip_norm_last:
+ # don't do this
+ hidden_states = self.skip_linear(cat)
+ hidden_states = self.skip_norm(hidden_states)
+ else:
+ cat = self.skip_norm(cat)
+ hidden_states = self.skip_linear(cat)
+
+ # 1. Self-Attention
+ if self.use_self_attention:
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_output = self.attn1(
+ norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **attention_kwargs,
+ )
+ hidden_states = hidden_states + attn_output
+
+ # 2. Cross-Attention
+ if self.use_cross_attention:
+ hidden_states = hidden_states + self.attn2(
+ self.norm2(hidden_states),
+ encoder_hidden_states=encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **attention_kwargs,
+ )
+
+ # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
+ mlp_inputs = self.norm3(hidden_states)
+ hidden_states = hidden_states + self.ff(mlp_inputs)
+
+ return hidden_states
+
+# Modified from https://github.com/VAST-AI-Research/TripoSG/blob/main/triposg/models/transformers/triposg_transformer.py#L365
+class PartCrafterDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+    PartCrafter: Diffusion model with a Transformer backbone for part-level 3D generation, modified from TripoSG.
+
+    Inherits ModelMixin and ConfigMixin so it is compatible with the diffusers pipeline and sampler machinery.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            The number of heads to use for multi-head attention.
+        width (`int`, *optional*, defaults to 2048):
+            The hidden size of the transformer.
+        in_channels (`int`, *optional*, defaults to 64):
+            The number of channels in the input and output latent tokens.
+        num_layers (`int`, *optional*, defaults to 21):
+            The number of DiT blocks to use.
+        cross_attention_dim (`int`, *optional*, defaults to 1024):
+            The dimension of the image-condition embedding used for cross attention.
+        max_num_parts (`int`, *optional*, defaults to 32):
+            The maximum number of parts, i.e. the size of the learned part-id embedding table.
+        enable_part_embedding (`bool`, *optional*, defaults to `True`):
+            Whether to add a learned part-id embedding to the tokens of each part.
+        enable_local_cross_attn (`bool`, *optional*, defaults to `True`):
+            Whether non-global blocks cross-attend to the image condition.
+        enable_global_cross_attn (`bool`, *optional*, defaults to `True`):
+            Whether global blocks cross-attend to the image condition.
+        global_attn_block_ids (`List[int]`, *optional*):
+            Indices of the blocks that use the PartCrafter attention processor instead of the TripoSG one.
+        global_attn_block_id_range (`List[int]`, *optional*):
+            Inclusive `[start, end]` range that is expanded into `global_attn_block_ids`.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ width: int = 2048,
+ in_channels: int = 64,
+ num_layers: int = 21,
+ cross_attention_dim: int = 1024,
+ max_num_parts: int = 32,
+ enable_part_embedding=True,
+ enable_local_cross_attn: bool = True,
+ enable_global_cross_attn: bool = True,
+ global_attn_block_ids: Optional[List[int]] = None,
+ global_attn_block_id_range: Optional[List[int]] = None,
+ ):
+ super().__init__()
+ self.out_channels = in_channels
+ self.num_heads = num_attention_heads
+ self.inner_dim = width
+ self.mlp_ratio = 4.0
+
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
+ "positional",
+ inner_dim=self.inner_dim,
+ flip_sin_to_cos=False,
+ freq_shift=0,
+ time_embedding_dim=None,
+ )
+ self.time_proj = TimestepEmbedding(
+ timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
+ )
+
+ if enable_part_embedding:
+ self.part_embedding = nn.Embedding(max_num_parts, self.inner_dim)
+ self.part_embedding.weight.data.normal_(mean=0.0, std=0.02)
+ self.enable_part_embedding = enable_part_embedding
+
+ self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
+
+ self.blocks = nn.ModuleList(
+ [
+ DiTBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ use_self_attention=True,
+ self_attention_norm_type="fp32_layer_norm",
+ use_cross_attention=True,
+ cross_attention_dim=cross_attention_dim,
+ cross_attention_norm_type=None,
+ activation_fn="gelu",
+ norm_type="fp32_layer_norm", # TODO
+ norm_eps=1e-5,
+ ff_inner_dim=int(self.inner_dim * self.mlp_ratio),
+ skip=layer > num_layers // 2,
+ skip_concat_front=True,
+ skip_norm_last=True, # this is an error
+ qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
+ qkv_bias=False,
+ )
+ for layer in range(num_layers)
+ ]
+ )
+
+ self.norm_out = LayerNorm(self.inner_dim)
+ self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ self.enable_local_cross_attn = enable_local_cross_attn
+ self.enable_global_cross_attn = enable_global_cross_attn
+
+ if global_attn_block_ids is None:
+ global_attn_block_ids = []
+ if global_attn_block_id_range is not None:
+ global_attn_block_ids = list(range(global_attn_block_id_range[0], global_attn_block_id_range[1] + 1))
+ self.global_attn_block_ids = global_attn_block_ids
+
+ if len(global_attn_block_ids) > 0:
+ # Override self-attention processors for global attention blocks
+ attn_processor_dict = {}
+ modified_attn_processor = []
+ for layer_id in range(num_layers):
+ for attn_id in [1, 2]:
+ if layer_id in global_attn_block_ids:
+ # apply to both self-attention and cross-attention
+ attn_processor_dict[f'blocks.{layer_id}.attn{attn_id}.processor'] = PartCrafterAttnProcessor()
+ modified_attn_processor.append(f'blocks.{layer_id}.attn{attn_id}.processor')
+ else:
+ attn_processor_dict[f'blocks.{layer_id}.attn{attn_id}.processor'] = TripoSGAttnProcessor2_0()
+ self.set_attn_processor(attn_processor_dict)
+ # logger.info(f"Modified {modified_attn_processor} to PartCrafterAttnProcessor")
+
+ def _set_gradient_checkpointing(
+ self,
+ enable: bool = False,
+ gradient_checkpointing_func: Optional[Callable] = None,
+ ):
+ # TODO: implement gradient checkpointing
+ self.gradient_checkpointing = enable
+
+ def _set_time_proj(
+ self,
+ time_embedding_type: str,
+ inner_dim: int,
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ time_embedding_dim: int,
+ ) -> Tuple[int, int]:
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or inner_dim * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
+ )
+ self.time_embed = GaussianFourierProjection(
+ time_embed_dim // 2,
+ set_W_to_weight=False,
+ log=False,
+ flip_sin_to_cos=flip_sin_to_cos,
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or inner_dim * 4
+
+ self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ timestep_input_dim = inner_dim
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ return time_embed_dim, timestep_input_dim
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError(
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
+ )
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
+
+ def forward(
+ self,
+ hidden_states: Optional[torch.Tensor],
+ timestep: Union[int, float, torch.LongTensor],
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ """
+        The [`PartCrafterDiTModel`] forward method.
+
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch_size, num_tokens, in_channels)`):
+                The input latent tokens, with all parts stacked along the batch dimension.
+            timestep (`torch.LongTensor`, *optional*):
+                Used to indicate the denoising step.
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_len, embed_dim)`, *optional*):
+                Conditional image embeddings for the cross-attention layers.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`Transformer1DModelOutput`] instead of a plain tuple.
+ """
+
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if (
+ attention_kwargs is not None
+ and attention_kwargs.get("scale", None) is not None
+ ):
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ _, T, _ = hidden_states.shape
+
+ temb = self.time_embed(timestep).to(hidden_states.dtype)
+ temb = self.time_proj(temb)
+ temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states
+
+ hidden_states = self.proj_in(hidden_states)
+
+ # T + 1 token
+ hidden_states = torch.cat([temb, hidden_states], dim=1) # (N, T+1, D)
+
+ if self.enable_part_embedding:
+ # Add part embedding
+ num_parts = attention_kwargs["num_parts"]
+ if isinstance(num_parts, torch.Tensor):
+ part_embeddings = []
+ for num_part in num_parts:
+ part_embedding = self.part_embedding(torch.arange(num_part, device=hidden_states.device)) # (n, D)
+ part_embeddings.append(part_embedding)
+ part_embedding = torch.cat(part_embeddings, dim=0) # (N, D)
+ elif isinstance(num_parts, int):
+ part_embedding = self.part_embedding(torch.arange(hidden_states.shape[0], device=hidden_states.device)) # (N, D)
+ else:
+ raise ValueError(
+ "num_parts must be a torch.Tensor or int, but got {}".format(type(num_parts))
+ )
+ hidden_states = hidden_states + part_embedding.unsqueeze(dim=1) # (N, T+1, D)
+
+ # prepare negative encoder_hidden_states
+ negative_encoder_hidden_states = torch.zeros_like(encoder_hidden_states) if encoder_hidden_states is not None else None
+
+ skips = []
+ for layer, block in enumerate(self.blocks):
+ skip = None if layer <= self.config.num_layers // 2 else skips.pop()
+ if (
+ (not self.enable_local_cross_attn)
+ and len(self.global_attn_block_ids) > 0
+ and (layer not in self.global_attn_block_ids)
+ ):
+ # If in non-global attention block and disable local cross attention, use negative encoder_hidden_states
+ # Do not inject control signal into non-global attention block
+ input_encoder_hidden_states = negative_encoder_hidden_states
+ elif (
+ (not self.enable_global_cross_attn)
+ and len(self.global_attn_block_ids) > 0
+ and (layer in self.global_attn_block_ids)
+ ):
+ # If in global attention block and disable global cross attention, use negative encoder_hidden_states
+ # Do not inject control signal into global attention block
+ input_encoder_hidden_states = negative_encoder_hidden_states
+ else:
+ input_encoder_hidden_states = encoder_hidden_states
+
+ if len(self.global_attn_block_ids) > 0 and (layer in self.global_attn_block_ids):
+ # Inject control signal into global attention block
+ input_attention_kwargs = attention_kwargs
+ else:
+ input_attention_kwargs = None
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = (
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ input_encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ skip,
+ input_attention_kwargs,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=input_encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ skip=skip,
+ attention_kwargs=input_attention_kwargs,
+ ) # (N, T+1, D)
+
+ if layer < self.config.num_layers // 2:
+ skips.append(hidden_states)
+
+ # final layer
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = hidden_states[:, -T:] # (N, T, D)
+ hidden_states = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer1DModelOutput(sample=hidden_states)
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(
+ self, chunk_size: Optional[int] = None, dim: int = 0
+ ) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(
+ module: torch.nn.Module, chunk_size: int, dim: int
+ ):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+ def disable_forward_chunking(self):
+ def fn_recursive_feed_forward(
+ module: torch.nn.Module, chunk_size: int, dim: int
+ ):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, None, 0)
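To make the block wiring above easier to follow, here is a small standalone sketch (illustrative, using the default num_layers=21) of the U-ViT-style long-skip pairing implemented by `skip=layer > num_layers // 2` in `__init__` and the push/pop logic in `forward`: the first ten blocks push their outputs, block 10 is the middle block, and blocks 11-20 consume the skips in reverse order. Likewise, `global_attn_block_id_range=[a, b]` is expanded to the inclusive id list `range(a, b + 1)`.

    num_layers = 21
    skips, pairs = [], {}
    for layer in range(num_layers):
        skip = None if layer <= num_layers // 2 else skips.pop()
        if skip is not None:
            pairs[layer] = skip          # decoder block -> encoder block whose output it reuses
        if layer < num_layers // 2:
            skips.append(layer)

    print(pairs)  # {11: 9, 12: 8, ..., 20: 0}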
diff --git a/src/pipelines/pipeline_partcrafter.py b/src/pipelines/pipeline_partcrafter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cff35c79b82bf83253f0ad9aa10362a0268670ec
--- /dev/null
+++ b/src/pipelines/pipeline_partcrafter.py
@@ -0,0 +1,355 @@
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import PIL.Image
+import torch
+import trimesh
+from diffusers.image_processor import PipelineImageInput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import (
+ BitImageProcessor,
+ Dinov2Model,
+)
+from ..utils.inference_utils import hierarchical_extract_geometry, flash_extract_geometry
+
+from ..models.autoencoders import TripoSGVAEModel
+from ..models.transformers import PartCrafterDiTModel
+from .pipeline_partcrafter_output import PartCrafterPipelineOutput
+from .pipeline_utils import TransformerDiffusionMixin
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError(
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+ )
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class PartCrafterPipeline(DiffusionPipeline, TransformerDiffusionMixin):
+ """
+ Pipeline for image to 3D part-level object generation.
+ """
+
+ def __init__(
+ self,
+ vae: TripoSGVAEModel,
+ transformer: PartCrafterDiTModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ image_encoder_dinov2: Dinov2Model,
+ feature_extractor_dinov2: BitImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder_dinov2=image_encoder_dinov2,
+ feature_extractor_dinov2=feature_extractor_dinov2,
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def decode_progressive(self):
+ return self._decode_progressive
+
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder_dinov2.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor_dinov2(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder_dinov2(image).last_hidden_state
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_tokens,
+ num_channels_latents,
+ dtype,
+ device,
+ generator,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ shape = (batch_size, num_tokens, num_channels_latents)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return noise
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ num_inference_steps: int = 50,
+ num_tokens: int = 2048,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ bounds: Union[Tuple[float], List[float], float] = (-1.005, -1.005, -1.005, 1.005, 1.005, 1.005),
+ dense_octree_depth: int = 8,
+ hierarchical_octree_depth: int = 9,
+ max_num_expanded_coords: int = 1e8,
+ flash_octree_depth: int = 9,
+ use_flash_decoder: bool = True,
+ return_dict: bool = True,
+ ):
+ # 1. Define call parameters
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+        # 2. Determine the batch size from the input image(s)
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, list):
+ batch_size = len(image)
+ elif isinstance(image, torch.Tensor):
+ batch_size = image.shape[0]
+ else:
+ raise ValueError("Invalid input type for image")
+
+ device = self._execution_device
+ dtype = self.image_encoder_dinov2.dtype
+
+ # 3. Encode condition
+ image_embeds, negative_image_embeds = self.encode_image(
+ image, device, num_images_per_prompt
+ )
+
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps
+ )
+ num_warmup_steps = max(
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
+ )
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_tokens,
+ num_channels_latents,
+ image_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ self.set_progress_bar_config(
+ desc="Denoising",
+ ncols=125,
+ disable=self._progress_bar_config['disable'] if hasattr(self, '_progress_bar_config') else False,
+ )
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents] * 2)
+ if self.do_classifier_free_guidance
+ else latents
+ )
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0].to(dtype)
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_image = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
+ noise_pred_image - noise_pred_uncond
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(
+ noise_pred, t, latents, return_dict=False
+ )[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+                    # this pipeline only conditions on the DINOv2 image embeddings
+                    image_embeds = callback_outputs.pop("image_embeds", image_embeds)
+                    negative_image_embeds = callback_outputs.pop(
+                        "negative_image_embeds", negative_image_embeds
+                    )
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+
+        # 7. Decode meshes from the denoised latents
+ self.vae.set_flash_decoder()
+ output, meshes = [], []
+ self.set_progress_bar_config(
+ desc="Decoding",
+ ncols=125,
+ disable=self._progress_bar_config['disable'] if hasattr(self, '_progress_bar_config') else False,
+ )
+ with self.progress_bar(total=batch_size) as progress_bar:
+ for i in range(batch_size):
+ geometric_func = lambda x: self.vae.decode(latents[i].unsqueeze(0), sampled_points=x).sample
+ try:
+ mesh_v_f = hierarchical_extract_geometry(
+ geometric_func,
+ device,
+ dtype=latents.dtype,
+ bounds=bounds,
+ dense_octree_depth=dense_octree_depth,
+ hierarchical_octree_depth=hierarchical_octree_depth,
+ max_num_expanded_coords=max_num_expanded_coords,
+ # verbose=True
+ )
+ mesh = trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1])
+                except Exception as e:
+                    logger.warning(f"Mesh extraction failed for sample {i}: {e}")
+                    mesh_v_f = None
+                    mesh = None
+ output.append(mesh_v_f)
+ meshes.append(mesh)
+ progress_bar.update()
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (output, meshes)
+
+ return PartCrafterPipelineOutput(samples=output, meshes=meshes)
+
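A rough end-to-end usage sketch for the pipeline above. The checkpoint id is a placeholder, the one-image-per-part convention and the exact form expected for `attention_kwargs["num_parts"]` are assumptions read off the transformer code above, not something this diff specifies.

    import torch
    from PIL import Image

    from src.pipelines.pipeline_partcrafter import PartCrafterPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Placeholder checkpoint id; substitute the actual PartCrafter weights.
    pipe = PartCrafterPipeline.from_pretrained("path/to/partcrafter-checkpoint").to(device)

    # Assumed convention: one conditioning image per part to generate.
    num_parts = 4
    images = [Image.open("object.png").convert("RGB")] * num_parts
    out = pipe(
        image=images,
        num_inference_steps=50,
        guidance_scale=7.0,
        # "num_parts" is read by the transformer's part embedding and global attention;
        # the expected type (int vs. per-sample tensor) may differ from this sketch.
        attention_kwargs={"num_parts": num_parts},
    )
    meshes = out.meshes  # list of trimesh.Trimesh, None where extraction failed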
diff --git a/src/pipelines/pipeline_partcrafter_output.py b/src/pipelines/pipeline_partcrafter_output.py
new file mode 100644
index 0000000000000000000000000000000000000000..656e6b72be0eb64e6420c76818388ed3e6ab22b5
--- /dev/null
+++ b/src/pipelines/pipeline_partcrafter_output.py
@@ -0,0 +1,17 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+import trimesh
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class PartCrafterPipelineOutput(BaseOutput):
+    r"""
+    Output class for PartCrafter pipelines.
+
+    Args:
+        samples: the raw extracted geometry (vertices, faces) for each generated part, `None` where extraction failed.
+        meshes: the corresponding `trimesh.Trimesh` objects, `None` where extraction failed.
+    """
+
+ samples: torch.Tensor
+ meshes: List[trimesh.Trimesh]
diff --git a/src/pipelines/pipeline_utils.py b/src/pipelines/pipeline_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2dc4413b7c859e0dc88ade2aa12629414026274
--- /dev/null
+++ b/src/pipelines/pipeline_utils.py
@@ -0,0 +1,96 @@
+from diffusers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class TransformerDiffusionMixin:
+ r"""
+    Helper mixin for DiffusionPipeline subclasses that use a VAE and a transformer (mainly for DiT models).
+ """
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+
+ Args:
+ transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+ """
+ self.fusing_transformer = False
+ self.fusing_vae = False
+
+ if transformer:
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ if vae:
+ self.fusing_vae = True
+ self.vae.fuse_qkv_projections()
+
+ def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
+ """Disable QKV projection fusion if enabled.
+
+        This API is 🧪 experimental.
+
+        Args:
+            transformer (`bool`, defaults to `True`): Whether to unfuse the Transformer's QKV projections.
+            vae (`bool`, defaults to `True`): Whether to unfuse the VAE's QKV projections.
+ """
+ if transformer:
+ if not self.fusing_transformer:
+ logger.warning(
+                    "The Transformer was not initially fused for QKV projections. Doing nothing."
+ )
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
+
+ if vae:
+ if not self.fusing_vae:
+ logger.warning(
+ "The VAE was not initially fused for QKV projections. Doing nothing."
+ )
+ else:
+ self.vae.unfuse_qkv_projections()
+ self.fusing_vae = False
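+
+    # Usage sketch (assumed workflow; not exercised elsewhere in this diff):
+    #   pipe.fuse_qkv_projections(transformer=True, vae=True)  # fuse before repeated inference calls
+    #   ...                                                    # run the pipeline
+    #   pipe.unfuse_qkv_projections()                          # restore the original projections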
diff --git a/src/schedulers/__init__.py b/src/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..619c3313f7b568c7500aeafa1ed91ad1b85400e5
--- /dev/null
+++ b/src/schedulers/__init__.py
@@ -0,0 +1,5 @@
+from .scheduling_rectified_flow import (
+ RectifiedFlowScheduler,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting,
+)
diff --git a/src/schedulers/scheduling_rectified_flow.py b/src/schedulers/scheduling_rectified_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..d97c1bb1d833165a6150b1e5173232a49e93b7ea
--- /dev/null
+++ b/src/schedulers/scheduling_rectified_flow.py
@@ -0,0 +1,327 @@
+"""
+Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py.
+"""
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.utils import BaseOutput, logging
+from torch.distributions import LogisticNormal
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# TODO: may move to training_utils.py
+def compute_density_for_timestep_sampling(
+ weighting_scheme: str,
+ batch_size: int,
+ logit_mean: float = 0.0,
+ logit_std: float = 1.0,
+    mode_scale: Optional[float] = None,
+):
+ if weighting_scheme == "logit_normal":
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+ u = torch.normal(
+ mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
+ )
+ u = torch.nn.functional.sigmoid(u)
+ elif weighting_scheme == "logit_normal_dist":
+ u = (
+ LogisticNormal(loc=logit_mean, scale=logit_std)
+ .sample((batch_size,))[:, 0]
+ .to("cpu")
+ )
+ elif weighting_scheme == "mode":
+ u = torch.rand(size=(batch_size,), device="cpu")
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+ else:
+ u = torch.rand(size=(batch_size,), device="cpu")
+ return u
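+
+# Example (as used in src/train_partcrafter.py): the returned `u` typically lies in [0, 1) and is mapped
+# to discrete training timesteps via
+#   indices = (u * noise_scheduler.config.num_train_timesteps).long()
+#   timesteps = noise_scheduler.timesteps[indices]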
+
+
+def compute_loss_weighting(weighting_scheme: str, sigmas=None):
+ """
+ Computes loss weighting scheme for SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ """
+ if weighting_scheme == "sigma_sqrt":
+ weighting = (sigmas**-2.0).float()
+ elif weighting_scheme == "cosmap":
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
+ weighting = 2 / (math.pi * bot)
+ else:
+ weighting = torch.ones_like(sigmas)
+ return weighting
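+
+# Worked example: for the "cosmap" scheme at sigma = 0.5, bot = 1 - 2*0.5 + 2*0.5**2 = 0.5, so the
+# weight is 2 / (pi * 0.5) ≈ 1.273; the "sigma_sqrt" scheme at the same sigma gives 0.5**-2.0 = 4.0.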
+
+
+@dataclass
+class RectifiedFlowSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin):
+ """
+    The rectified flow scheduler propagates the diffusion process along a rectified-flow (linear
+    noise-to-data interpolation) path using a simple Euler update.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+        shift (`float`, defaults to 1.0):
+            The shift value for the timestep schedule.
+        use_dynamic_shifting (`bool`, defaults to `False`):
+            Whether to apply timestep shifting on the fly (e.g., based on resolution) instead of using the
+            static `shift` value.
+ """
+
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,
+ use_dynamic_shifting: bool = False,
+ ):
+        # Pre-compute timesteps and sigmas; they are not actually used during training:
+        # NOTE that shape diffusion samples timesteps randomly (or from a distribution)
+        # instead of from this pre-defined linspace.
+ timesteps = np.array(
+ [
+ (1.0 - i / num_train_timesteps) * num_train_timesteps
+ for i in range(num_train_timesteps)
+ ]
+ )
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
+
+ sigmas = timesteps / num_train_timesteps
+ if not use_dynamic_shifting:
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+ sigmas = self.time_shift(sigmas)
+
+ self.timesteps = sigmas * num_train_timesteps
+
+ self._step_index = None
+ self._begin_index = None
+
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def _t_to_sigma(self, timestep):
+ return timestep / self.config.num_train_timesteps
+
+ def time_shift_dynamic(self, mu: float, sigma: float, t: torch.Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+ def time_shift(self, t: torch.Tensor):
+ return self.config.shift * t / (1 + (self.config.shift - 1) * t)
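+
+    # Worked example: with shift = 3.0, a raw sigma of t = 0.5 becomes
+    # 3.0 * 0.5 / (1 + (3.0 - 1) * 0.5) = 0.75, i.e. the schedule is pushed towards higher noise levels.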
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int = None,
+ device: Union[str, torch.device] = None,
+ sigmas: Optional[List[float]] = None,
+ mu: Optional[float] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+
+ if self.config.use_dynamic_shifting and mu is None:
+ raise ValueError(
+                "You have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`."
+ )
+
+ if sigmas is None:
+ self.num_inference_steps = num_inference_steps
+ timesteps = np.array(
+ [
+ (1.0 - i / num_inference_steps) * self.config.num_train_timesteps
+ for i in range(num_inference_steps)
+ ]
+ ) # different from the original code in SD3
+ sigmas = timesteps / self.config.num_train_timesteps
+
+ if self.config.use_dynamic_shifting:
+ sigmas = self.time_shift_dynamic(mu, 1.0, sigmas)
+ else:
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
+
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+ timesteps = sigmas * self.config.num_train_timesteps
+
+ self.timesteps = timesteps.to(device=device)
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+ self._step_index = None
+ self._begin_index = None
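+
+        # Example: with the default num_train_timesteps=1000, num_inference_steps=50 and shift=1.0,
+        # sigmas run 1.00, 0.98, ..., 0.02 (plus a trailing 0.0) and timesteps run 1000, 980, ..., 20.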
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def _init_step_index(self, timestep):
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ s_churn: float = 0.0,
+ s_tmin: float = 0.0,
+ s_tmax: float = float("inf"),
+ s_noise: float = 1.0,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+            s_churn (`float`):
+                Unused; kept for API compatibility with Euler-style schedulers.
+            s_tmin (`float`):
+                Unused; kept for API compatibility with Euler-style schedulers.
+            s_tmax (`float`):
+                Unused; kept for API compatibility with Euler-style schedulers.
+            s_noise (`float`, defaults to 1.0):
+                Scaling factor for noise added to the sample.
+            generator (`torch.Generator`, *optional*):
+                A random number generator.
+            return_dict (`bool`):
+                Whether or not to return a [`RectifiedFlowSchedulerOutput`] or a plain tuple.
+
+        Returns:
+            [`RectifiedFlowSchedulerOutput`] or `tuple`:
+                If `return_dict` is `True`, [`RectifiedFlowSchedulerOutput`] is returned, otherwise a tuple is
+                returned where the first element is the sample tensor.
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                    " `RectifiedFlowScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+
+ sigma = self.sigmas[self.step_index]
+ sigma_next = self.sigmas[self.step_index + 1]
+
+ # Here different directions are used for the flow matching
+ prev_sample = sample + (sigma - sigma_next) * model_output
+
+ # Cast sample back to model compatible dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
+
+ def scale_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ """
+ Forward function for the noise scaling in the flow matching.
+ """
+ sigmas = self._t_to_sigma(timesteps.to(dtype=torch.float32))
+
+ while len(sigmas.shape) < len(original_samples.shape):
+ sigmas = sigmas.unsqueeze(-1)
+
+ return (1.0 - sigmas) * original_samples + sigmas * noise
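+
+    # Note: this is the rectified-flow forward process x_t = (1 - sigma) * x_0 + sigma * noise, so
+    # sigma = 1 (t = num_train_timesteps) yields pure noise and sigma = 0 returns the clean sample.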
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/src/train_partcrafter.py b/src/train_partcrafter.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b3aea662d0137734496d349219d7cfb7537ad8c
--- /dev/null
+++ b/src/train_partcrafter.py
@@ -0,0 +1,1007 @@
+import warnings
+warnings.filterwarnings("ignore") # ignore all warnings
+import diffusers.utils.logging as diffusion_logging
+diffusion_logging.set_verbosity_error() # ignore diffusers warnings
+
+from src.utils.typing_utils import *
+
+import os
+import argparse
+import logging
+import time
+import math
+import gc
+from packaging import version
+
+import trimesh
+from PIL import Image
+import numpy as np
+import wandb
+from tqdm import tqdm
+
+import torch
+import torch.nn.functional as tF
+import accelerate
+from accelerate import Accelerator
+from accelerate.logging import get_logger as get_accelerate_logger
+from accelerate import DataLoaderConfiguration, DeepSpeedPlugin
+from diffusers.training_utils import (
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3
+)
+
+from transformers import (
+ BitImageProcessor,
+ Dinov2Model,
+)
+from src.schedulers import RectifiedFlowScheduler
+from src.models.autoencoders import TripoSGVAEModel
+from src.models.transformers import PartCrafterDiTModel
+from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
+
+from src.datasets import (
+ ObjaversePartDataset,
+ BatchedObjaversePartDataset,
+ MultiEpochsDataLoader,
+ yield_forever
+)
+from src.utils.data_utils import get_colored_mesh_composition
+from src.utils.train_utils import (
+ MyEMAModel,
+ get_configs,
+ get_optimizer,
+ get_lr_scheduler,
+ save_experiment_params,
+ save_model_architecture,
+)
+from src.utils.render_utils import (
+ render_views_around_mesh,
+ render_normal_views_around_mesh,
+ make_grid_for_images_or_videos,
+ export_renderings
+)
+from src.utils.metric_utils import compute_cd_and_f_score_in_training
+
+def main():
+ PROJECT_NAME = "PartCrafter"
+
+ parser = argparse.ArgumentParser(
+ description="Train a diffusion model for 3D object generation",
+ )
+
+ parser.add_argument(
+ "--config",
+ type=str,
+ required=True,
+ help="Path to the config file"
+ )
+ parser.add_argument(
+ "--tag",
+ type=str,
+ default=None,
+ help="Tag that refers to the current experiment"
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="output",
+ help="Path to the output directory"
+ )
+ parser.add_argument(
+ "--resume_from_iter",
+ type=int,
+ default=None,
+ help="The iteration to load the checkpoint from"
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=0,
+ help="Seed for the PRNG"
+ )
+ parser.add_argument(
+ "--offline_wandb",
+ action="store_true",
+ help="Use offline WandB for experiment tracking"
+ )
+
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="The max iteration step for training"
+ )
+ parser.add_argument(
+ "--max_val_steps",
+ type=int,
+ default=2,
+ help="The max iteration step for validation"
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=32,
+        help="The number of worker processes spawned by the data loader"
+ )
+ parser.add_argument(
+ "--pin_memory",
+ action="store_true",
+ help="Pin memory for the data loader"
+ )
+
+ parser.add_argument(
+ "--use_ema",
+ action="store_true",
+ help="Use EMA model for training"
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ help="Scale lr with total batch size (base batch size: 256)"
+ )
+ parser.add_argument(
+ "--max_grad_norm",
+ type=float,
+ default=1.,
+ help="Max gradient norm for gradient clipping"
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass"
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="fp16",
+ choices=["no", "fp16", "bf16"],
+ help="Type of mixed precision training"
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help="Enable TF32 for faster training on Ampere GPUs"
+ )
+
+ parser.add_argument(
+ "--val_guidance_scales",
+        type=float,
+        nargs="+",
+        default=[7.0],
+        help="CFG scales used for validation"
+ )
+
+ parser.add_argument(
+ "--use_deepspeed",
+ action="store_true",
+ help="Use DeepSpeed for training"
+ )
+ parser.add_argument(
+ "--zero_stage",
+ type=int,
+ default=1,
+ choices=[1, 2, 3], # https://huggingface.co/docs/accelerate/usage_guides/deepspeed
+ help="ZeRO stage type for DeepSpeed"
+ )
+
+ parser.add_argument(
+ "--from_scratch",
+ action="store_true",
+ help="Train from scratch"
+ )
+ parser.add_argument(
+ "--load_pretrained_model",
+ type=str,
+ default=None,
+ help="Tag of a pretrained PartCrafterDiTModel in this project"
+ )
+ parser.add_argument(
+ "--load_pretrained_model_ckpt",
+ type=int,
+ default=-1,
+ help="Iteration of the pretrained PartCrafterDiTModel checkpoint"
+ )
+
+ # Parse the arguments
+ args, extras = parser.parse_known_args()
+ # Parse the config file
+ configs = get_configs(args.config, extras) # change yaml configs by `extras`
+
+    args.val_guidance_scales = [float(x) for x in args.val_guidance_scales]
+ if args.max_val_steps > 0:
+        # If validation is enabled, max_val_steps must be a multiple of nrow
+        # The validation batch size is always kept at 1
+ divider = configs["val"]["nrow"]
+ args.max_val_steps = max(args.max_val_steps, divider)
+ if args.max_val_steps % divider != 0:
+ args.max_val_steps = (args.max_val_steps // divider + 1) * divider
+
+ # Create an experiment directory using the `tag`
+ if args.tag is None:
+ args.tag = time.strftime("%Y%m%d_%H_%M_%S")
+ exp_dir = os.path.join(args.output_dir, args.tag)
+ ckpt_dir = os.path.join(exp_dir, "checkpoints")
+ eval_dir = os.path.join(exp_dir, "evaluations")
+ os.makedirs(ckpt_dir, exist_ok=True)
+ os.makedirs(eval_dir, exist_ok=True)
+
+ # Initialize the logger
+ logging.basicConfig(
+ format="%(asctime)s - %(message)s",
+ datefmt="%Y/%m/%d %H:%M:%S",
+ level=logging.INFO
+ )
+ logger = get_accelerate_logger(__name__, log_level="INFO")
+ file_handler = logging.FileHandler(os.path.join(exp_dir, "log.txt")) # output to file
+ file_handler.setFormatter(logging.Formatter(
+ fmt="%(asctime)s - %(message)s",
+ datefmt="%Y/%m/%d %H:%M:%S"
+ ))
+ logger.logger.addHandler(file_handler)
+ logger.logger.propagate = True # propagate to the root logger (console)
+
+ # Set DeepSpeed config
+ if args.use_deepspeed:
+ deepspeed_plugin = DeepSpeedPlugin(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ gradient_clipping=args.max_grad_norm,
+ zero_stage=int(args.zero_stage),
+ offload_optimizer_device="cpu", # hard-coded here, TODO: make it configurable
+ )
+ else:
+ deepspeed_plugin = None
+
+ # Initialize the accelerator
+ accelerator = Accelerator(
+ project_dir=exp_dir,
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ split_batches=False, # batch size per GPU
+ dataloader_config=DataLoaderConfiguration(non_blocking=args.pin_memory),
+ deepspeed_plugin=deepspeed_plugin,
+ )
+ logger.info(f"Accelerator state:\n{accelerator.state}\n")
+
+ # Set the random seed
+ if args.seed >= 0:
+ accelerate.utils.set_seed(args.seed)
+ logger.info(f"You have chosen to seed([{args.seed}]) the experiment [{args.tag}]\n")
+
+ # Enable TF32 for faster training on Ampere GPUs
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ train_dataset = BatchedObjaversePartDataset(
+ configs=configs,
+ batch_size=configs["train"]["batch_size_per_gpu"],
+ is_main_process=accelerator.is_main_process,
+ shuffle=True,
+ training=True,
+ )
+ val_dataset = ObjaversePartDataset(
+ configs=configs,
+ training=False,
+ )
+ train_loader = MultiEpochsDataLoader(
+ train_dataset,
+ batch_size=configs["train"]["batch_size_per_gpu"],
+ num_workers=args.num_workers,
+ drop_last=True,
+ pin_memory=args.pin_memory,
+ collate_fn=train_dataset.collate_fn,
+ )
+ val_loader = MultiEpochsDataLoader(
+ val_dataset,
+ batch_size=configs["val"]["batch_size_per_gpu"],
+ num_workers=args.num_workers,
+ drop_last=True,
+ pin_memory=args.pin_memory,
+ )
+ random_val_loader = MultiEpochsDataLoader(
+ val_dataset,
+ batch_size=configs["val"]["batch_size_per_gpu"],
+ shuffle=True,
+ num_workers=args.num_workers,
+ drop_last=True,
+ pin_memory=args.pin_memory,
+ )
+
+ logger.info(f"Loaded [{len(train_dataset)}] training samples and [{len(val_dataset)}] validation samples\n")
+
+ # Compute the effective batch size and scale learning rate
+ total_batch_size = configs["train"]["batch_size_per_gpu"] * \
+ accelerator.num_processes * args.gradient_accumulation_steps
+ configs["train"]["total_batch_size"] = total_batch_size
+ if args.scale_lr:
+ configs["optimizer"]["lr"] *= (total_batch_size / 256)
+ configs["lr_scheduler"]["max_lr"] = configs["optimizer"]["lr"]
+
+ # Initialize the model
+ logger.info("Initializing the model...")
+ vae = TripoSGVAEModel.from_pretrained(
+ configs["model"]["pretrained_model_name_or_path"],
+ subfolder="vae"
+ )
+ feature_extractor_dinov2 = BitImageProcessor.from_pretrained(
+ configs["model"]["pretrained_model_name_or_path"],
+ subfolder="feature_extractor_dinov2"
+ )
+ image_encoder_dinov2 = Dinov2Model.from_pretrained(
+ configs["model"]["pretrained_model_name_or_path"],
+ subfolder="image_encoder_dinov2"
+ )
+
+ enable_part_embedding = configs["model"]["transformer"].get("enable_part_embedding", True)
+ enable_local_cross_attn = configs["model"]["transformer"].get("enable_local_cross_attn", True)
+ enable_global_cross_attn = configs["model"]["transformer"].get("enable_global_cross_attn", True)
+ global_attn_block_ids = configs["model"]["transformer"].get("global_attn_block_ids", None)
+ if global_attn_block_ids is not None:
+ global_attn_block_ids = list(global_attn_block_ids)
+ global_attn_block_id_range = configs["model"]["transformer"].get("global_attn_block_id_range", None)
+ if global_attn_block_id_range is not None:
+ global_attn_block_id_range = list(global_attn_block_id_range)
+ if args.from_scratch:
+ logger.info(f"Initialize PartCrafterDiTModel from scratch\n")
+ transformer = PartCrafterDiTModel.from_config(
+ os.path.join(
+ configs["model"]["pretrained_model_name_or_path"],
+ "transformer"
+ ),
+ enable_part_embedding=enable_part_embedding,
+ enable_local_cross_attn=enable_local_cross_attn,
+ enable_global_cross_attn=enable_global_cross_attn,
+ global_attn_block_ids=global_attn_block_ids,
+ global_attn_block_id_range=global_attn_block_id_range,
+ )
+ elif args.load_pretrained_model is None:
+ logger.info(f"Load pretrained TripoSGDiTModel to initialize PartCrafterDiTModel from [{configs['model']['pretrained_model_name_or_path']}]\n")
+ transformer, loading_info = PartCrafterDiTModel.from_pretrained(
+ configs["model"]["pretrained_model_name_or_path"],
+ subfolder="transformer",
+ low_cpu_mem_usage=False,
+ output_loading_info=True,
+ enable_part_embedding=enable_part_embedding,
+ enable_local_cross_attn=enable_local_cross_attn,
+ enable_global_cross_attn=enable_global_cross_attn,
+ global_attn_block_ids=global_attn_block_ids,
+ global_attn_block_id_range=global_attn_block_id_range,
+ )
+ else:
+ logger.info(f"Load PartCrafterDiTModel EMA checkpoint from [{args.load_pretrained_model}] iteration [{args.load_pretrained_model_ckpt:06d}]\n")
+ path = os.path.join(
+ args.output_dir,
+ args.load_pretrained_model,
+ "checkpoints",
+ f"{args.load_pretrained_model_ckpt:06d}"
+ )
+ transformer, loading_info = PartCrafterDiTModel.from_pretrained(
+ path,
+ subfolder="transformer_ema",
+ low_cpu_mem_usage=False,
+ output_loading_info=True,
+ enable_part_embedding=enable_part_embedding,
+ enable_local_cross_attn=enable_local_cross_attn,
+ enable_global_cross_attn=enable_global_cross_attn,
+ global_attn_block_ids=global_attn_block_ids,
+ global_attn_block_id_range=global_attn_block_id_range,
+ )
+ if not args.from_scratch:
+ for v in loading_info.values():
+ if v and len(v) > 0:
+ logger.info(f"Loading info of PartCrafterDiTModel: {loading_info}\n")
+ break
+
+ noise_scheduler = RectifiedFlowScheduler.from_pretrained(
+ configs["model"]["pretrained_model_name_or_path"],
+ subfolder="scheduler"
+ )
+
+ if args.use_ema:
+ ema_transformer = MyEMAModel(
+ transformer.parameters(),
+ model_cls=PartCrafterDiTModel,
+ model_config=transformer.config,
+ **configs["train"]["ema_kwargs"]
+ )
+
+ # Freeze VAE and image encoder
+ vae.requires_grad_(False)
+ image_encoder_dinov2.requires_grad_(False)
+ vae.eval()
+ image_encoder_dinov2.eval()
+
+ trainable_modules = configs["train"].get("trainable_modules", None)
+ if trainable_modules is None:
+ transformer.requires_grad_(True)
+ else:
+ trainable_module_names = []
+ transformer.requires_grad_(False)
+ for name, module in transformer.named_modules():
+ for module_name in tuple(trainable_modules.split(",")):
+ if module_name in name:
+ for params in module.parameters():
+ params.requires_grad = True
+ trainable_module_names.append(name)
+ logger.info(f"Trainable parameter names: {trainable_module_names}\n")
+
+ # transformer.enable_xformers_memory_efficient_attention() # use `tF.scaled_dot_product_attention` instead
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # Create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_transformer.save_pretrained(os.path.join(output_dir, "transformer_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "transformer"))
+
+ # Make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = MyEMAModel.from_pretrained(os.path.join(input_dir, "transformer_ema"), PartCrafterDiTModel)
+ ema_transformer.load_state_dict(load_model.state_dict())
+ ema_transformer.to(accelerator.device)
+ del load_model
+
+ for _ in range(len(models)):
+ # Pop models so that they are not loaded again
+ model = models.pop()
+
+ # Load diffusers style into model
+ load_model = PartCrafterDiTModel.from_pretrained(input_dir, subfolder="transformer")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if configs["train"]["grad_checkpoint"]:
+ transformer.enable_gradient_checkpointing()
+
+ # Initialize the optimizer and learning rate scheduler
+ logger.info("Initializing the optimizer and learning rate scheduler...\n")
+ name_lr_mult = configs["train"].get("name_lr_mult", None)
+ lr_mult = configs["train"].get("lr_mult", 1.0)
+ params, params_lr_mult, names_lr_mult = [], [], []
+ for name, param in transformer.named_parameters():
+ if name_lr_mult is not None:
+ for k in name_lr_mult.split(","):
+ if k in name:
+ params_lr_mult.append(param)
+ names_lr_mult.append(name)
+ if name not in names_lr_mult:
+ params.append(param)
+ else:
+ params.append(param)
+ optimizer = get_optimizer(
+ params=[
+ {"params": params, "lr": configs["optimizer"]["lr"]},
+ {"params": params_lr_mult, "lr": configs["optimizer"]["lr"] * lr_mult}
+ ],
+ **configs["optimizer"]
+ )
+ if name_lr_mult is not None:
+ logger.info(f"Learning rate x [{lr_mult}] parameter names: {names_lr_mult}\n")
+
+ configs["lr_scheduler"]["total_steps"] = configs["train"]["epochs"] * math.ceil(
+        len(train_loader) // accelerator.num_processes / args.gradient_accumulation_steps) # only count optimizer update steps
+ configs["lr_scheduler"]["total_steps"] *= accelerator.num_processes # for lr scheduler setting
+ if "num_warmup_steps" in configs["lr_scheduler"]:
+ configs["lr_scheduler"]["num_warmup_steps"] *= accelerator.num_processes # for lr scheduler setting
+ lr_scheduler = get_lr_scheduler(optimizer=optimizer, **configs["lr_scheduler"])
+ configs["lr_scheduler"]["total_steps"] //= accelerator.num_processes # reset for multi-gpu
+ if "num_warmup_steps" in configs["lr_scheduler"]:
+ configs["lr_scheduler"]["num_warmup_steps"] //= accelerator.num_processes # reset for multi-gpu
+
+ # Prepare everything with `accelerator`
+ transformer, optimizer, lr_scheduler, train_loader, val_loader, random_val_loader = accelerator.prepare(
+ transformer, optimizer, lr_scheduler, train_loader, val_loader, random_val_loader
+ )
+ # Set classes explicitly for everything
+ transformer: DistributedDataParallel
+ optimizer: AcceleratedOptimizer
+ lr_scheduler: AcceleratedScheduler
+ train_loader: DataLoaderShard
+ val_loader: DataLoaderShard
+ random_val_loader: DataLoaderShard
+
+ if args.use_ema:
+ ema_transformer.to(accelerator.device)
+
+    # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move `vae` and `image_encoder_dinov2` to gpu and cast to `weight_dtype`
+ vae.to(accelerator.device, dtype=weight_dtype)
+ image_encoder_dinov2.to(accelerator.device, dtype=weight_dtype)
+
+ # Training configs after distribution and accumulation setup
+ updated_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
+ total_updated_steps = configs["lr_scheduler"]["total_steps"]
+ if args.max_train_steps is None:
+ args.max_train_steps = total_updated_steps
+ assert configs["train"]["epochs"] * updated_steps_per_epoch == total_updated_steps
+ if accelerator.num_processes > 1 and accelerator.is_main_process:
+ print()
+ accelerator.wait_for_everyone()
+ logger.info(f"Total batch size: [{total_batch_size}]")
+ logger.info(f"Learning rate: [{configs['optimizer']['lr']}]")
+ logger.info(f"Gradient Accumulation steps: [{args.gradient_accumulation_steps}]")
+ logger.info(f"Total epochs: [{configs['train']['epochs']}]")
+ logger.info(f"Total steps: [{total_updated_steps}]")
+ logger.info(f"Steps for updating per epoch: [{updated_steps_per_epoch}]")
+ logger.info(f"Steps for validation: [{len(val_loader)}]\n")
+
+ # (Optional) Load checkpoint
+ global_update_step = 0
+ if args.resume_from_iter is not None:
+ if args.resume_from_iter < 0:
+ args.resume_from_iter = int(sorted(os.listdir(ckpt_dir))[-1])
+ logger.info(f"Load checkpoint from iteration [{args.resume_from_iter}]\n")
+ # Load everything
+ if version.parse(torch.__version__) >= version.parse("2.4.0"):
+ torch.serialization.add_safe_globals([
+ int, list, dict,
+ defaultdict,
+ Any,
+ DictConfig, ListConfig, Metadata, ContainerMetadata, AnyNode
+ ]) # avoid deserialization error when loading optimizer state
+ accelerator.load_state(os.path.join(ckpt_dir, f"{args.resume_from_iter:06d}")) # torch < 2.4.0 here for `weights_only=False`
+ global_update_step = int(args.resume_from_iter)
+
+ # Save all experimental parameters and model architecture of this run to a file (args and configs)
+ if accelerator.is_main_process:
+ exp_params = save_experiment_params(args, configs, exp_dir)
+ save_model_architecture(accelerator.unwrap_model(transformer), exp_dir)
+
+ # WandB logger
+ if accelerator.is_main_process:
+ if args.offline_wandb:
+ os.environ["WANDB_MODE"] = "offline"
+ wandb.init(
+ project=PROJECT_NAME, name=args.tag,
+ config=exp_params, dir=exp_dir,
+ resume=True
+ )
+ # Wandb artifact for logging experiment information
+ arti_exp_info = wandb.Artifact(args.tag, type="exp_info")
+ arti_exp_info.add_file(os.path.join(exp_dir, "params.yaml"))
+ arti_exp_info.add_file(os.path.join(exp_dir, "model.txt"))
+ arti_exp_info.add_file(os.path.join(exp_dir, "log.txt")) # only save the log before training
+ wandb.log_artifact(arti_exp_info)
+
+ def get_sigmas(timesteps: Tensor, n_dim: int, dtype=torch.float32):
+ sigmas = noise_scheduler.sigmas.to(dtype=dtype, device=accelerator.device)
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero()[0].item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
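+
+    # Example: for latents of shape [N, tokens, channels] (an assumed layout), get_sigmas(timesteps, n_dim=3)
+    # returns an [N, 1, 1] tensor that broadcasts in noisy_latents = (1 - sigmas) * latents + sigmas * noise.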
+
+ # Start training
+ if accelerator.is_main_process:
+ print()
+ logger.info(f"Start training into {exp_dir}\n")
+ logger.logger.propagate = False # not propagate to the root logger (console)
+ progress_bar = tqdm(
+ range(total_updated_steps),
+ initial=global_update_step,
+ desc="Training",
+ ncols=125,
+ disable=not accelerator.is_main_process
+ )
+ for batch in yield_forever(train_loader):
+
+ if global_update_step == args.max_train_steps:
+ progress_bar.close()
+ logger.logger.propagate = True # propagate to the root logger (console)
+ if accelerator.is_main_process:
+ wandb.finish()
+ logger.info("Training finished!\n")
+ return
+
+ transformer.train()
+
+ with accelerator.accumulate(transformer):
+
+ images = batch["images"] # [N, H, W, 3]
+ with torch.no_grad():
+ images = feature_extractor_dinov2(images=images, return_tensors="pt").pixel_values
+ images = images.to(device=accelerator.device, dtype=weight_dtype)
+ with torch.no_grad():
+ image_embeds = image_encoder_dinov2(images).last_hidden_state
+ negative_image_embeds = torch.zeros_like(image_embeds)
+
+ part_surfaces = batch["part_surfaces"] # [N, P, 6]
+ part_surfaces = part_surfaces.to(device=accelerator.device, dtype=weight_dtype)
+
+ num_parts = batch["num_parts"] # [M, ] The shape of num_parts is not fixed
+ num_objects = num_parts.shape[0] # M
+
+ with torch.no_grad():
+ latents = vae.encode(
+ part_surfaces,
+ **configs["model"]["vae"]
+ ).latent_dist.sample()
+
+ noise = torch.randn_like(latents)
+ # For weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=configs["train"]["weighting_scheme"],
+ batch_size=num_objects,
+ logit_mean=configs["train"]["logit_mean"],
+ logit_std=configs["train"]["logit_std"],
+ mode_scale=configs["train"]["mode_scale"],
+ )
+ indices = (u * noise_scheduler.config.num_train_timesteps).long()
+ timesteps = noise_scheduler.timesteps[indices].to(accelerator.device) # [M, ]
+ # Repeat the timesteps for each part
+ timesteps = timesteps.repeat_interleave(num_parts) # [N, ]
+
+ sigmas = get_sigmas(timesteps, len(latents.shape), weight_dtype)
+ latent_model_input = noisy_latents = (1. - sigmas) * latents + sigmas * noise
+
+ if configs["train"]["cfg_dropout_prob"] > 0:
+ # We use the same dropout mask for the same part
+ dropout_mask = torch.rand(num_objects, device=accelerator.device) < configs["train"]["cfg_dropout_prob"] # [M, ]
+ dropout_mask = dropout_mask.repeat_interleave(num_parts) # [N, ]
+ if dropout_mask.any():
+ image_embeds[dropout_mask] = negative_image_embeds[dropout_mask]
+
+ model_pred = transformer(
+ hidden_states=latent_model_input,
+ timestep=timesteps,
+ encoder_hidden_states=image_embeds,
+ attention_kwargs={"num_parts": num_parts}
+ ).sample
+
+ if configs["train"]["training_objective"] == "x0": # Section 5 of https://arxiv.org/abs/2206.00364
+ model_pred = model_pred * (-sigmas) + noisy_latents # predicted x_0
+ target = latents
+ elif configs["train"]["training_objective"] == 'v': # flow matching
+ target = noise - latents
+ elif configs["train"]["training_objective"] == '-v': # reverse flow matching
+ # The training objective for TripoSG is the reverse of the flow matching objective.
+ # It uses "different directions", i.e., the negative velocity.
+ # This is probably a mistake in engineering, not very harmful.
+ # In TripoSG's rectified flow scheduler, prev_sample = sample + (sigma - sigma_next) * model_output
+ # See TripoSG's scheduler https://github.com/VAST-AI-Research/TripoSG/blob/main/triposg/schedulers/scheduling_rectified_flow.py#L296
+ # While in diffusers's flow matching scheduler, prev_sample = sample + (sigma_next - sigma) * model_output
+ # See https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L454
+ target = latents - noise
+ else:
+ raise ValueError(f"Unknown training objective [{configs['train']['training_objective']}]")
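+
+            # Summary of the objectives above:
+            #   "x0": the prediction is converted to x_0 (model_pred * (-sigmas) + noisy_latents) and
+            #         regressed against `latents`;
+            #   "v" : target = noise - latents, the standard flow-matching velocity (diffusers convention);
+            #   "-v": target = latents - noise, the negated velocity matching this repo's
+            #         RectifiedFlowScheduler update prev_sample = sample + (sigma - sigma_next) * model_output.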
+
+ # For these weighting schemes use a uniform timestep sampling, so post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(
+ configs["train"]["weighting_scheme"],
+ sigmas
+ )
+
+ loss = weighting * tF.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape))))
+
+ # Backpropagate
+ accelerator.backward(loss.mean())
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ # Gather the losses across all processes for logging (if we use distributed training)
+ loss = accelerator.gather(loss.detach()).mean()
+
+ logs = {
+ "loss": loss.item(),
+ "lr": lr_scheduler.get_last_lr()[0]
+ }
+ if args.use_ema:
+ ema_transformer.step(transformer.parameters())
+ logs.update({"ema": ema_transformer.cur_decay_value})
+
+ progress_bar.set_postfix(**logs)
+ progress_bar.update(1)
+ global_update_step += 1
+
+            logger.info(
+                f"[{global_update_step:06d} / {total_updated_steps:06d}] " +
+                f"loss: {logs['loss']:.4f}, lr: {logs['lr']:.2e}" +
+                (f", ema: {logs['ema']:.4f}" if args.use_ema else "")
+            )
+
+ # Log the training progress
+ if (
+ global_update_step % configs["train"]["log_freq"] == 0
+ or global_update_step == 1
+ or global_update_step % updated_steps_per_epoch == 0 # last step of an epoch
+ ):
+ if accelerator.is_main_process:
+ wandb.log({
+ "training/loss": logs["loss"],
+ "training/lr": logs["lr"],
+ }, step=global_update_step)
+ if args.use_ema:
+ wandb.log({
+ "training/ema": logs["ema"]
+ }, step=global_update_step)
+
+ # Save checkpoint
+ if (
+ global_update_step % configs["train"]["save_freq"] == 0 # 1. every `save_freq` steps
+ or global_update_step % (configs["train"]["save_freq_epoch"] * updated_steps_per_epoch) == 0 # 2. every `save_freq_epoch` epochs
+ or global_update_step == total_updated_steps # 3. last step of an epoch
+ # or global_update_step == 1 # 4. first step
+ ):
+
+ gc.collect()
+ if accelerator.distributed_type == accelerate.utils.DistributedType.DEEPSPEED:
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues
+ accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}"))
+ elif accelerator.is_main_process:
+ accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}"))
+ accelerator.wait_for_everyone() # ensure all processes have finished saving
+ gc.collect()
+
+ # Evaluate on the validation set
+ if args.max_val_steps > 0 and (
+ (global_update_step % configs["train"]["early_eval_freq"] == 0 and global_update_step < configs["train"]["early_eval"]) # 1. more frequently at the beginning
+ or global_update_step % configs["train"]["eval_freq"] == 0 # 2. every `eval_freq` steps
+ or global_update_step % (configs["train"]["eval_freq_epoch"] * updated_steps_per_epoch) == 0 # 3. every `eval_freq_epoch` epochs
+ or global_update_step == total_updated_steps # 4. last step of an epoch
+ or global_update_step == 1 # 5. first step
+ ):
+
+ # Use EMA parameters for evaluation
+ if args.use_ema:
+ # Store the Transformer parameters temporarily and load the EMA parameters to perform inference
+ ema_transformer.store(transformer.parameters())
+ ema_transformer.copy_to(transformer.parameters())
+
+ transformer.eval()
+
+ log_validation(
+ val_loader, random_val_loader,
+ feature_extractor_dinov2, image_encoder_dinov2,
+ vae, transformer,
+ global_update_step, eval_dir,
+ accelerator, logger,
+ args, configs
+ )
+
+ if args.use_ema:
+ # Switch back to the original Transformer parameters
+ ema_transformer.restore(transformer.parameters())
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+@torch.no_grad()
+def log_validation(
+ dataloader, random_dataloader,
+ feature_extractor_dinov2, image_encoder_dinov2,
+ vae, transformer,
+ global_step, eval_dir,
+ accelerator, logger,
+ args, configs
+):
+
+ val_noise_scheduler = RectifiedFlowScheduler.from_pretrained(
+ configs["model"]["pretrained_model_name_or_path"],
+ subfolder="scheduler"
+ )
+
+ pipeline = PartCrafterPipeline(
+ vae=vae,
+ transformer=accelerator.unwrap_model(transformer),
+ scheduler=val_noise_scheduler,
+ feature_extractor_dinov2=feature_extractor_dinov2,
+ image_encoder_dinov2=image_encoder_dinov2,
+ )
+
+ pipeline.set_progress_bar_config(disable=True)
+ # pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed >= 0:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ else:
+ generator = None
+
+
+ val_progress_bar = tqdm(
+ range(len(dataloader)) if args.max_val_steps is None else range(args.max_val_steps),
+ desc=f"Validation [{global_step:06d}]",
+ ncols=125,
+ disable=not accelerator.is_main_process
+ )
+
+ medias_dictlist, metrics_dictlist = defaultdict(list), defaultdict(list)
+
+    val_dataloader, random_val_dataloader = yield_forever(dataloader), yield_forever(random_dataloader)
+ val_step = 0
+ while val_step < args.max_val_steps:
+
+ if val_step < args.max_val_steps // 2:
+ # fix the first half
+            batch = next(val_dataloader)
+ else:
+ # randomly sample the next batch
+ batch = next(random_val_dataloader)
+
+ images = batch["images"]
+ if len(images.shape) == 5:
+ images = images[0] # (1, N, H, W, 3) -> (N, H, W, 3)
+ images = [Image.fromarray(image) for image in images.cpu().numpy()]
+ part_surfaces = batch["part_surfaces"].cpu().numpy()
+ if len(part_surfaces.shape) == 4:
+ part_surfaces = part_surfaces[0] # (1, N, P, 6) -> (N, P, 6)
+
+ N = len(images)
+
+ val_progress_bar.set_postfix(
+ {"num_parts": N}
+ )
+
+ with torch.autocast("cuda", torch.float16):
+ for guidance_scale in sorted(args.val_guidance_scales):
+ pred_part_meshes = pipeline(
+ images,
+ num_inference_steps=configs['val']['num_inference_steps'],
+ num_tokens=configs['model']['vae']['num_tokens'],
+ guidance_scale=guidance_scale,
+ attention_kwargs={"num_parts": N},
+ generator=generator,
+ max_num_expanded_coords=configs['val']['max_num_expanded_coords'],
+ use_flash_decoder=configs['val']['use_flash_decoder'],
+ ).meshes
+
+ # Save the generated meshes
+ if accelerator.is_main_process:
+ local_eval_dir = os.path.join(eval_dir, f"{global_step:06d}", f"guidance_scale_{guidance_scale:.1f}")
+ os.makedirs(local_eval_dir, exist_ok=True)
+ rendered_images_list, rendered_normals_list = [], []
+ # 1. save the gt image
+ images[0].save(os.path.join(local_eval_dir, f"{val_step:04d}.png"))
+ # 2. save the generated part meshes
+ for n in range(N):
+ if pred_part_meshes[n] is None:
+                            # If the generated mesh is None (decoding error), use a dummy mesh
+ pred_part_meshes[n] = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
+ pred_part_meshes[n].export(os.path.join(local_eval_dir, f"{val_step:04d}_{n:02d}.glb"))
+ # 3. render the generated mesh and save the rendered images
+ pred_mesh = get_colored_mesh_composition(pred_part_meshes)
+ rendered_images: List[Image.Image] = render_views_around_mesh(
+ pred_mesh,
+ num_views=configs['val']['rendering']['num_views'],
+ radius=configs['val']['rendering']['radius'],
+ )
+ rendered_normals: List[Image.Image] = render_normal_views_around_mesh(
+ pred_mesh,
+ num_views=configs['val']['rendering']['num_views'],
+ radius=configs['val']['rendering']['radius'],
+ )
+ export_renderings(
+ rendered_images,
+ os.path.join(local_eval_dir, f"{val_step:04d}.gif"),
+ fps=configs['val']['rendering']['fps']
+ )
+ export_renderings(
+ rendered_normals,
+ os.path.join(local_eval_dir, f"{val_step:04d}_normals.gif"),
+ fps=configs['val']['rendering']['fps']
+ )
+ rendered_images_list.append(rendered_images)
+ rendered_normals_list.append(rendered_normals)
+
+ medias_dictlist[f"guidance_scale_{guidance_scale:.1f}/gt_image"] += [images[0]] # List[Image.Image] TODO: support batch size > 1
+ medias_dictlist[f"guidance_scale_{guidance_scale:.1f}/pred_rendered_images"] += rendered_images_list # List[List[Image.Image]]
+ medias_dictlist[f"guidance_scale_{guidance_scale:.1f}/pred_rendered_normals"] += rendered_normals_list # List[List[Image.Image]]
+
+ ################################ Compute generation metrics ################################
+
+ parts_chamfer_distances, parts_f_scores = [], []
+
+ for n in range(N):
+ # gt_part_surface = part_surfaces[n]
+ # pred_part_mesh = pred_part_meshes[n]
+ # if pred_part_mesh is None:
+ # # If the generated mesh is None (decoing error), use a dummy mesh
+ # pred_part_mesh = trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
+ # part_cd, part_f = compute_cd_and_f_score_in_training(
+ # gt_part_surface, pred_part_mesh,
+ # num_samples=configs['val']['metric']['cd_num_samples'],
+ # threshold=configs['val']['metric']['f1_score_threshold'],
+ # metric=configs['val']['metric']['cd_metric']
+ # )
+ # # avoid nan
+ # part_cd = configs['val']['metric']['default_cd'] if np.isnan(part_cd) else part_cd
+ # part_f = configs['val']['metric']['default_f1'] if np.isnan(part_f) else part_f
+ # parts_chamfer_distances.append(part_cd)
+ # parts_f_scores.append(part_f)
+
+ # TODO: Fix this
+ # Disable chamfer distance and F1 score for now
+ parts_chamfer_distances.append(0.0)
+ parts_f_scores.append(0.0)
+
+ parts_chamfer_distances = torch.tensor(parts_chamfer_distances, device=accelerator.device)
+ parts_f_scores = torch.tensor(parts_f_scores, device=accelerator.device)
+
+ metrics_dictlist[f"parts_chamfer_distance_cfg{guidance_scale:.1f}"].append(parts_chamfer_distances.mean())
+ metrics_dictlist[f"parts_f_score_cfg{guidance_scale:.1f}"].append(parts_f_scores.mean())
+
+ # Only log the last (biggest) cfg metrics in the progress bar
+ val_logs = {
+ "parts_chamfer_distance": parts_chamfer_distances.mean().item(),
+ "parts_f_score": parts_f_scores.mean().item(),
+ }
+ val_progress_bar.set_postfix(**val_logs)
+ logger.info(
+ f"Validation [{val_step:02d}/{args.max_val_steps:02d}] " +
+ f"parts_chamfer_distance: {val_logs['parts_chamfer_distance']:.4f}, parts_f_score: {val_logs['parts_f_score']:.4f}"
+ )
+ logger.info(
+ f"parts_chamfer_distances: {[f'{x:.4f}' for x in parts_chamfer_distances.tolist()]}"
+ )
+ logger.info(
+ f"parts_f_scores: {[f'{x:.4f}' for x in parts_f_scores.tolist()]}"
+ )
+ val_step += 1
+ val_progress_bar.update(1)
+
+ val_progress_bar.close()
+
+ if accelerator.is_main_process:
+ for key, value in medias_dictlist.items():
+ if isinstance(value[0], Image.Image): # assuming gt_image
+ image_grid = make_grid_for_images_or_videos(
+ value,
+ nrow=configs['val']['nrow'],
+ return_type='pil',
+ )
+ image_grid.save(os.path.join(eval_dir, f"{global_step:06d}", f"{key}.png"))
+ wandb.log({f"validation/{key}": wandb.Image(image_grid)}, step=global_step)
+ else: # assuming pred_rendered_images or pred_rendered_normals
+ image_grids = make_grid_for_images_or_videos(
+ value,
+ nrow=configs['val']['nrow'],
+ return_type='ndarray',
+ )
+ wandb.log({
+ f"validation/{key}": wandb.Video(
+ image_grids,
+ fps=configs['val']['rendering']['fps'],
+ format="gif"
+ )}, step=global_step)
+ image_grids = [Image.fromarray(image_grid.transpose(1, 2, 0)) for image_grid in image_grids]
+ export_renderings(
+ image_grids,
+ os.path.join(eval_dir, f"{global_step:06d}", f"{key}.gif"),
+ fps=configs['val']['rendering']['fps']
+ )
+
+ for k, v in metrics_dictlist.items():
+ wandb.log({f"validation/{k}": torch.tensor(v).mean().item()}, step=global_step)
+
+if __name__ == "__main__":
+ main()
diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..229e6376dc524b04b2e5f4336d7f2e8fb716c2db
--- /dev/null
+++ b/src/utils/data_utils.py
@@ -0,0 +1,191 @@
+from src.utils.typing_utils import *
+
+import os
+import numpy as np
+import trimesh
+import torch
+
+def normalize_mesh(
+ mesh: Union[trimesh.Trimesh, trimesh.Scene],
+ scale: float = 2.0,
+):
+ # if not isinstance(mesh, trimesh.Trimesh) and not isinstance(mesh, trimesh.Scene):
+ # raise ValueError("Input mesh is not a trimesh.Trimesh or trimesh.Scene object.")
+ bbox = mesh.bounding_box
+ translation = -bbox.centroid
+ scale = scale / bbox.primitive.extents.max()
+ mesh.apply_translation(translation)
+ mesh.apply_scale(scale)
+ return mesh
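+
+# Worked example: a mesh with bounding-box extents (1, 2, 4) and scale=2.0 is first centered at the origin
+# and then scaled by 2.0 / 4 = 0.5, so its longest side spans [-1, 1] afterwards.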
+
+def remove_overlapping_vertices(mesh: trimesh.Trimesh, reserve_material: bool = False):
+ if not isinstance(mesh, trimesh.Trimesh):
+ raise ValueError("Input mesh is not a trimesh.Trimesh object.")
+ vertices = mesh.vertices
+ faces = mesh.faces
+ unique_vertices, index_map, inverse_map = np.unique(
+ vertices, axis=0, return_index=True, return_inverse=True
+ )
+ clean_faces = inverse_map[faces]
+ clean_mesh = trimesh.Trimesh(vertices=unique_vertices, faces=clean_faces, process=True)
+ if reserve_material:
+ uv = mesh.visual.uv
+ material = mesh.visual.material
+ clean_uv = uv[index_map]
+ clean_visual = trimesh.visual.TextureVisuals(uv=clean_uv, material=material)
+ clean_mesh.visual = clean_visual
+ return clean_mesh
+
+RGB = [
+ (82, 170, 220),
+ (215, 91, 78),
+ (45, 136, 117),
+ (247, 172, 83),
+ (124, 121, 121),
+ (127, 171, 209),
+ (243, 152, 101),
+ (145, 204, 192),
+ (150, 59, 121),
+ (181, 206, 78),
+ (189, 119, 149),
+ (199, 193, 222),
+ (200, 151, 54),
+ (236, 110, 102),
+ (238, 182, 212),
+]
+
+
+def get_colored_mesh_composition(
+ meshes: Union[List[trimesh.Trimesh], trimesh.Scene],
+ is_random: bool = True,
+ is_sorted: bool = False,
+ RGB: List[Tuple] = RGB
+):
+ if isinstance(meshes, trimesh.Scene):
+ meshes = meshes.dump()
+ if is_sorted:
+ volumes = []
+ for mesh in meshes:
+ try:
+ volume = mesh.volume
+            except Exception:
+ volume = 0.0
+ volumes.append(volume)
+ # sort by volume from large to small
+ meshes = [x for _, x in sorted(zip(volumes, meshes), key=lambda pair: pair[0], reverse=True)]
+ colored_scene = trimesh.Scene()
+ for idx, mesh in enumerate(meshes):
+ if is_random:
+ color = (np.random.rand(3) * 256).astype(int)
+ else:
+ color = np.array(RGB[idx % len(RGB)])
+ mesh.visual = trimesh.visual.ColorVisuals(
+ mesh=mesh,
+ vertex_colors=color,
+ )
+ colored_scene.add_geometry(mesh)
+ return colored_scene
+
+def mesh_to_surface(
+ mesh: trimesh.Trimesh,
+ num_pc: int = 204800,
+ clip_to_num_vertices: bool = False,
+ return_dict: bool = False,
+):
+ # if not isinstance(mesh, trimesh.Trimesh):
+ # raise ValueError("mesh must be a trimesh.Trimesh object")
+ if clip_to_num_vertices:
+ num_pc = min(num_pc, mesh.vertices.shape[0])
+ points, face_indices = mesh.sample(num_pc, return_index=True)
+ normals = mesh.face_normals[face_indices]
+ if return_dict:
+ return {
+ "surface_points": points,
+ "surface_normals": normals,
+ }
+ return points, normals
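+
+# Example: `points, normals = mesh_to_surface(mesh, num_pc=204800)` yields two (204800, 3) arrays; with
+# return_dict=True the same data is keyed as "surface_points" / "surface_normals", which is the format
+# expected by `load_surface` below.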
+
+def scene_to_parts(
+ mesh: trimesh.Scene,
+ normalize: bool = True,
+ scale: float = 2.0,
+ num_part_pc: int = 204800,
+ clip_to_num_part_vertices: bool = False,
+ return_type: Literal["mesh", "point"] = "mesh",
+) -> Union[List[trimesh.Geometry], List[Dict[str, np.ndarray]]]:
+ if not isinstance(mesh, trimesh.Scene):
+ raise ValueError("mesh must be a trimesh.Scene object")
+ if normalize:
+ mesh = normalize_mesh(mesh, scale=scale)
+ parts: List[trimesh.Geometry] = mesh.dump()
+ if return_type == "point":
+ datas: List[Dict[str, np.ndarray]] = []
+ for geom in parts:
+ data = mesh_to_surface(
+ geom,
+ num_pc=num_part_pc,
+ clip_to_num_vertices=clip_to_num_part_vertices,
+ return_dict=True,
+ )
+ datas.append(data)
+ return datas
+ elif return_type == "mesh":
+ return parts
+ else:
+ raise ValueError("return_type must be 'mesh' or 'point'")
+
+def get_center(mesh: trimesh.Trimesh, method: Literal['mass', 'bbox']):
+ if method == 'mass':
+ return mesh.center_mass
+    elif method == 'bbox':
+        return mesh.bounding_box.centroid
+    else:
+        raise ValueError("method must be 'mass' or 'bbox'")
+
+def get_direction(vector: np.ndarray):
+ return vector / np.linalg.norm(vector)
+
+def move_mesh_by_center(mesh: trimesh.Trimesh, scale: float, method: Literal['mass', 'bbox'] = 'mass'):
+ offset = scale - 1
+ center = get_center(mesh, method)
+ direction = get_direction(center)
+ translation = direction * offset
+ mesh = mesh.copy()
+ mesh.apply_translation(translation)
+ return mesh
+
+def move_meshes_by_center(meshes: Union[List[trimesh.Trimesh], trimesh.Scene], scale: float):
+ if isinstance(meshes, trimesh.Scene):
+ meshes = meshes.dump()
+ moved_meshes = []
+ for mesh in meshes:
+ moved_mesh = move_mesh_by_center(mesh, scale)
+ moved_meshes.append(moved_mesh)
+ moved_meshes = trimesh.Scene(moved_meshes)
+ return moved_meshes
+
+def get_series_splited_meshes(meshes: List[trimesh.Trimesh], scale: float, num_steps: int) -> List[trimesh.Scene]:
+ series_meshes = []
+ for i in range(num_steps):
+ temp_scale = 1 + (scale - 1) * i / (num_steps - 1)
+ temp_meshes = move_meshes_by_center(meshes, temp_scale)
+ series_meshes.append(temp_meshes)
+ return series_meshes
+
+def load_surface(data, num_pc=204800):
+
+ surface = data["surface_points"] # Nx3
+ normal = data["surface_normals"] # Nx3
+
+ rng = np.random.default_rng()
+ ind = rng.choice(surface.shape[0], num_pc, replace=False)
+ surface = torch.FloatTensor(surface[ind])
+ normal = torch.FloatTensor(normal[ind])
+ surface = torch.cat([surface, normal], dim=-1)
+
+ return surface
+
+def load_surfaces(surfaces, num_pc=204800):
+ surfaces = [load_surface(surface, num_pc) for surface in surfaces]
+ surfaces = torch.stack(surfaces, dim=0)
+ return surfaces
\ No newline at end of file
diff --git a/src/utils/image_utils.py b/src/utils/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6e0e0209fd3afecc73d67368715021059ee8a64
--- /dev/null
+++ b/src/utils/image_utils.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+import os
+from skimage.morphology import remove_small_objects
+from skimage.measure import label
+import numpy as np
+from PIL import Image
+import cv2
+from torchvision import transforms
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+
+def find_bounding_box(gray_image):
+ _, binary_image = cv2.threshold(gray_image, 1, 255, cv2.THRESH_BINARY)
+ contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ max_contour = max(contours, key=cv2.contourArea)
+ x, y, w, h = cv2.boundingRect(max_contour)
+ return x, y, w, h
+
+def load_image(img_path, bg_color=None, rmbg_net=None, padding_ratio=0.1, device='cuda'):
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+ if img is None:
+ return f"invalid image path {img_path}"
+
+ def is_valid_alpha(alpha, min_ratio = 0.01):
+ bins = 20
+ if isinstance(alpha, np.ndarray):
+ hist = cv2.calcHist([alpha], [0], None, [bins], [0, 256])
+ else:
+ hist = torch.histc(alpha, bins=bins, min=0, max=1)
+ min_hist_val = alpha.shape[0] * alpha.shape[1] * min_ratio
+ return hist[0] >= min_hist_val and hist[-1] >= min_hist_val
+
+ def rmbg(image: torch.Tensor) -> torch.Tensor:
+ image = TF.normalize(image, [0.5,0.5,0.5], [1.0,1.0,1.0]).unsqueeze(0)
+ result=rmbg_net(image)
+ return result[0][0]
+
+ if len(img.shape) == 2:
+ num_channels = 1
+ else:
+ num_channels = img.shape[2]
+
+ # check if too large
+ height, width = img.shape[:2]
+ if height > width:
+ scale = 2000 / height
+ else:
+ scale = 2000 / width
+ if scale < 1:
+ new_size = (int(width * scale), int(height * scale))
+ img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
+
+ if img.dtype != 'uint8':
+ img = (img * (255. / np.iinfo(img.dtype).max)).astype(np.uint8)
+
+ rgb_image = None
+ alpha = None
+
+ if num_channels == 1:
+ rgb_image = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ elif num_channels == 3:
+ rgb_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ elif num_channels == 4:
+ rgb_image = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
+
+ b, g, r, alpha = cv2.split(img)
+ if not is_valid_alpha(alpha):
+ alpha = None
+ else:
+ alpha_gpu = torch.from_numpy(alpha).unsqueeze(0).to(device).float() / 255.
+ else:
+ return f"invalid image: channels {num_channels}"
+
+ rgb_image_gpu = torch.from_numpy(rgb_image).to(device).float().permute(2, 0, 1) / 255.
+ if alpha is None:
+ resize_transform = transforms.Resize((384, 384), antialias=True)
+ rgb_image_resized = resize_transform(rgb_image_gpu)
+ normalize_image = rgb_image_resized * 2 - 1
+
+ mean_color = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(device)
+ resize_transform = transforms.Resize((1024, 1024), antialias=True)
+ rgb_image_resized = resize_transform(rgb_image_gpu)
+ max_value = rgb_image_resized.flatten().max()
+ if max_value < 1e-3:
+ return "invalid image: pure black image"
+ normalize_image = rgb_image_resized / max_value - mean_color
+ normalize_image = normalize_image.unsqueeze(0)
+ resize_transform = transforms.Resize((rgb_image_gpu.shape[1], rgb_image_gpu.shape[2]), antialias=True)
+
+ # seg from rmbg
+ alpha_gpu_rmbg = rmbg(rgb_image_resized)
+ alpha_gpu_rmbg = alpha_gpu_rmbg.squeeze(0)
+ alpha_gpu_rmbg = resize_transform(alpha_gpu_rmbg)
+ ma, mi = alpha_gpu_rmbg.max(), alpha_gpu_rmbg.min()
+ alpha_gpu_rmbg = (alpha_gpu_rmbg - mi) / (ma - mi)
+
+ alpha_gpu = alpha_gpu_rmbg
+
+ alpha_gpu_tmp = alpha_gpu * 255
+ alpha = alpha_gpu_tmp.to(torch.uint8).squeeze().cpu().numpy()
+
+ _, alpha = cv2.threshold(alpha, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
+ labeled_alpha = label(alpha)
+ cleaned_alpha = remove_small_objects(labeled_alpha, min_size=200)
+ cleaned_alpha = (cleaned_alpha > 0).astype(np.uint8)
+ alpha = cleaned_alpha * 255
+ alpha_gpu = torch.from_numpy(cleaned_alpha).to(device).float().unsqueeze(0)
+ x, y, w, h = find_bounding_box(alpha)
+
+    # If alpha is provided, the bounding box of all foreground pixels is used
+ else:
+ rows, cols = np.where(alpha > 0)
+ if rows.size > 0 and cols.size > 0:
+ x_min = np.min(cols)
+ y_min = np.min(rows)
+ x_max = np.max(cols)
+ y_max = np.max(rows)
+
+ width = x_max - x_min + 1
+ height = y_max - y_min + 1
+ x, y, w, h = x_min, y_min, width, height
+
+ if np.all(alpha==0):
+ raise ValueError(f"input image too small")
+
+ bg_gray = bg_color[0]
+ bg_color = torch.from_numpy(bg_color).float().to(device).repeat(alpha_gpu.shape[1], alpha_gpu.shape[2], 1).permute(2, 0, 1)
+ rgb_image_gpu = rgb_image_gpu * alpha_gpu + bg_color * (1 - alpha_gpu)
+ padding_size = [0] * 6
+ if w > h:
+ padding_size[0] = int(w * padding_ratio)
+ padding_size[2] = int(padding_size[0] + (w - h) / 2)
+ else:
+ padding_size[2] = int(h * padding_ratio)
+ padding_size[0] = int(padding_size[2] + (h - w) / 2)
+ padding_size[1] = padding_size[0]
+ padding_size[3] = padding_size[2]
+ padded_tensor = F.pad(rgb_image_gpu[:, y:(y+h), x:(x+w)], pad=tuple(padding_size), mode='constant', value=bg_gray)
+
+ return padded_tensor
+
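+# Usage sketch for load_image: RGBA inputs use their alpha channel as the mask;
+# RGB inputs additionally require an `rmbg_net` background-removal model (assumed
+# to follow a BriaRMBG-style interface, see `rmbg` above) on the same device.
+# "input.png" below is a placeholder path.
+# tensor = load_image("input.png", bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
+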
+def prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=None, padding_ratio=0.1, device='cuda'):
+ if os.path.isfile(image_path):
+ img_tensor = load_image(image_path, bg_color=bg_color, rmbg_net=rmbg_net, padding_ratio=padding_ratio, device=device)
+ img_np = img_tensor.permute(1,2,0).cpu().numpy()
+ img_pil = Image.fromarray((img_np*255).astype(np.uint8))
+
+ return img_pil
+ else:
+ raise ValueError(f"Invalid image path: {image_path}")
\ No newline at end of file
diff --git a/src/utils/inference_utils.py b/src/utils/inference_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..90465195857caa7df41c05693a2b57d209b39f27
--- /dev/null
+++ b/src/utils/inference_utils.py
@@ -0,0 +1,507 @@
+from src.utils.typing_utils import *
+
+import numpy as np
+import torch
+import torch.nn as nn
+import scipy.ndimage
+from skimage import measure
+from einops import repeat
+from diso import DiffDMC
+import torch.nn.functional as F
+
+def generate_dense_grid_points(
+ bbox_min: np.ndarray, bbox_max: np.ndarray, octree_depth: int, indexing: str = "ij"
+):
+ length = bbox_max - bbox_min
+ num_cells = np.exp2(octree_depth)
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
+ xyz = np.stack((xs, ys, zs), axis=-1)
+ xyz = xyz.reshape(-1, 3)
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
+
+ return xyz, grid_size, length
+
+def generate_dense_grid_points_gpu(
+ bbox_min: torch.Tensor,
+ bbox_max: torch.Tensor,
+ octree_depth: int,
+ indexing: str = "ij",
+ dtype: torch.dtype = torch.float16
+):
+ length = bbox_max - bbox_min
+ num_cells = 2 ** octree_depth
+ device = bbox_min.device
+
+ x = torch.linspace(bbox_min[0], bbox_max[0], int(num_cells), dtype=dtype, device=device)
+ y = torch.linspace(bbox_min[1], bbox_max[1], int(num_cells), dtype=dtype, device=device)
+ z = torch.linspace(bbox_min[2], bbox_max[2], int(num_cells), dtype=dtype, device=device)
+
+ xs, ys, zs = torch.meshgrid(x, y, z, indexing=indexing)
+ xyz = torch.stack((xs, ys, zs), dim=-1)
+ xyz = xyz.view(-1, 3)
+ grid_size = [int(num_cells), int(num_cells), int(num_cells)]
+
+ return xyz, grid_size, length
+
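+# Usage sketch: with octree_depth=8 this returns a 256**3 query grid spanning the
+# given bounds (note this GPU variant uses 2**depth samples per axis, not
+# 2**depth + 1 like the CPU version above).
+# bbox_min = torch.tensor([-1.0, -1.0, -1.0], device="cuda")
+# bbox_max = torch.tensor([1.0, 1.0, 1.0], device="cuda")
+# xyz, grid_size, _ = generate_dense_grid_points_gpu(bbox_min, bbox_max, octree_depth=8)
+# # xyz.shape == (256**3, 3); grid_size == [256, 256, 256]
+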
+def find_mesh_grid_coordinates_fast_gpu(
+ occupancy_grid,
+ n_limits=-1
+):
+ core_grid = occupancy_grid[1:-1, 1:-1, 1:-1]
+ occupied = core_grid > 0
+
+ neighbors_unoccupied = (
+ (occupancy_grid[:-2, :-2, :-2] < 0)
+ | (occupancy_grid[:-2, :-2, 1:-1] < 0)
+ | (occupancy_grid[:-2, :-2, 2:] < 0) # x-1, y-1, z-1/0/1
+ | (occupancy_grid[:-2, 1:-1, :-2] < 0)
+ | (occupancy_grid[:-2, 1:-1, 1:-1] < 0)
+ | (occupancy_grid[:-2, 1:-1, 2:] < 0) # x-1, y0, z-1/0/1
+ | (occupancy_grid[:-2, 2:, :-2] < 0)
+ | (occupancy_grid[:-2, 2:, 1:-1] < 0)
+ | (occupancy_grid[:-2, 2:, 2:] < 0) # x-1, y+1, z-1/0/1
+ | (occupancy_grid[1:-1, :-2, :-2] < 0)
+ | (occupancy_grid[1:-1, :-2, 1:-1] < 0)
+ | (occupancy_grid[1:-1, :-2, 2:] < 0) # x0, y-1, z-1/0/1
+ | (occupancy_grid[1:-1, 1:-1, :-2] < 0)
+ | (occupancy_grid[1:-1, 1:-1, 2:] < 0) # x0, y0, z-1/1
+ | (occupancy_grid[1:-1, 2:, :-2] < 0)
+ | (occupancy_grid[1:-1, 2:, 1:-1] < 0)
+ | (occupancy_grid[1:-1, 2:, 2:] < 0) # x0, y+1, z-1/0/1
+ | (occupancy_grid[2:, :-2, :-2] < 0)
+ | (occupancy_grid[2:, :-2, 1:-1] < 0)
+ | (occupancy_grid[2:, :-2, 2:] < 0) # x+1, y-1, z-1/0/1
+ | (occupancy_grid[2:, 1:-1, :-2] < 0)
+ | (occupancy_grid[2:, 1:-1, 1:-1] < 0)
+ | (occupancy_grid[2:, 1:-1, 2:] < 0) # x+1, y0, z-1/0/1
+ | (occupancy_grid[2:, 2:, :-2] < 0)
+ | (occupancy_grid[2:, 2:, 1:-1] < 0)
+ | (occupancy_grid[2:, 2:, 2:] < 0) # x+1, y+1, z-1/0/1
+ )
+ core_mesh_coords = torch.nonzero(occupied & neighbors_unoccupied, as_tuple=False) + 1
+
+ if n_limits != -1 and core_mesh_coords.shape[0] > n_limits:
+ print(f"core mesh coords {core_mesh_coords.shape[0]} is too large, limited to {n_limits}")
+ ind = np.random.choice(core_mesh_coords.shape[0], n_limits, True)
+ core_mesh_coords = core_mesh_coords[ind]
+
+ return core_mesh_coords
+
+def find_candidates_band(
+ occupancy_grid: torch.Tensor,
+ band_threshold: float,
+ n_limits: int = -1
+) -> torch.Tensor:
+ """
+ Returns the coordinates of all voxels in the occupancy_grid where |value| < band_threshold.
+
+ Args:
+ occupancy_grid (torch.Tensor): A 3D tensor of SDF values.
+ band_threshold (float): The threshold below which |SDF| must be to include the voxel.
+ n_limits (int): Maximum number of points to return (-1 for no limit)
+
+ Returns:
+ torch.Tensor: A 2D tensor of coordinates (N x 3) where each row is [x, y, z].
+ """
+ core_grid = occupancy_grid[1:-1, 1:-1, 1:-1]
+ # logits to sdf
+ core_grid = torch.sigmoid(core_grid) * 2 - 1
+ # Create a boolean mask for all cells in the band
+ in_band = torch.abs(core_grid) < band_threshold
+
+ # Get coordinates of all voxels in the band
+ core_mesh_coords = torch.nonzero(in_band, as_tuple=False) + 1
+
+ if n_limits != -1 and core_mesh_coords.shape[0] > n_limits:
+ print(f"core mesh coords {core_mesh_coords.shape[0]} is too large, limited to {n_limits}")
+ ind = np.random.choice(core_mesh_coords.shape[0], n_limits, True)
+ core_mesh_coords = core_mesh_coords[ind]
+
+ return core_mesh_coords
+
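+# Usage sketch: select the refinement band around the implicit surface, as done
+# inside hierarchical_extract_geometry below (`grid_logits` stands for an
+# [R, R, R] logits volume).
+# band_coords = find_candidates_band(grid_logits, band_threshold=1.0)
+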
+def expand_edge_region_fast(edge_coords, grid_size, dtype):
+ expanded_tensor = torch.zeros(grid_size, grid_size, grid_size, device='cuda', dtype=dtype, requires_grad=False)
+ expanded_tensor[edge_coords[:, 0], edge_coords[:, 1], edge_coords[:, 2]] = 1
+ if grid_size < 512:
+ kernel_size = 5
+ pooled_tensor = torch.nn.functional.max_pool3d(expanded_tensor.unsqueeze(0).unsqueeze(0), kernel_size=kernel_size, stride=1, padding=2).squeeze()
+ else:
+ kernel_size = 3
+ pooled_tensor = torch.nn.functional.max_pool3d(expanded_tensor.unsqueeze(0).unsqueeze(0), kernel_size=kernel_size, stride=1, padding=1).squeeze()
+ expanded_coords_low_res = torch.nonzero(pooled_tensor, as_tuple=False).to(torch.int16)
+
+    # Map each low-resolution voxel to its 8 children in the 2x-resolution grid.
+    child_offsets = torch.tensor(
+        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
+         [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]],
+        device=expanded_coords_low_res.device,
+        dtype=expanded_coords_low_res.dtype,
+    )
+    expanded_coords_high_res = (
+        expanded_coords_low_res.unsqueeze(1) * 2 + child_offsets.unsqueeze(0)
+    ).reshape(-1, 3)
+
+ return expanded_coords_high_res
+
+def zoom_block(block, scale_factor, order=3):
+ block = block.astype(np.float32)
+ return scipy.ndimage.zoom(block, scale_factor, order=order)
+
+def parallel_zoom(occupancy_grid, scale_factor):
+ result = torch.nn.functional.interpolate(occupancy_grid.unsqueeze(0).unsqueeze(0), scale_factor=scale_factor)
+ return result.squeeze(0).squeeze(0)
+
+
+@torch.no_grad()
+def hierarchical_extract_geometry(
+ geometric_func: Callable,
+ device: torch.device,
+ dtype: torch.dtype,
+ bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
+ dense_octree_depth: int = 8,
+ hierarchical_octree_depth: int = 9,
+ max_num_expanded_coords: int = 1e8,
+ verbose: bool = False,
+):
+    """
+    Coarse-to-fine isosurface extraction from an implicit field.
+
+    Args:
+        geometric_func: callable mapping query points of shape (1, N, 3) to logits of shape (1, N, 1).
+        device: torch device on which the query grids are built.
+        dtype: dtype used for grid evaluation (e.g. torch.float16).
+        bounds: (xmin, ymin, zmin, xmax, ymax, zmax) bounding box, or a single float for a symmetric box.
+        dense_octree_depth: depth of the initial dense grid (2**depth samples per axis).
+        hierarchical_octree_depth: final depth; only a band around the surface is re-queried up to this resolution.
+        max_num_expanded_coords: safety cap on the number of refined query points (<= 0 disables the check).
+        verbose: if True, print the number of queries at each level.
+    Returns:
+        Tuple (vertices, faces) extracted with marching cubes.
+    """
+ if isinstance(bounds, float):
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
+
+ bbox_min = torch.tensor(bounds[0:3]).to(device)
+ bbox_max = torch.tensor(bounds[3:6]).to(device)
+ bbox_size = bbox_max - bbox_min
+
+ xyz_samples, grid_size, length = generate_dense_grid_points_gpu(
+ bbox_min=bbox_min,
+ bbox_max=bbox_max,
+ octree_depth=dense_octree_depth,
+ indexing="ij",
+ dtype=dtype
+ )
+
+ if verbose:
+ print(f'step 1 query num: {xyz_samples.shape[0]}')
+ grid_logits = geometric_func(xyz_samples.unsqueeze(0)).to(dtype).view(grid_size[0], grid_size[1], grid_size[2])
+ # print(f'step 1 grid_logits shape: {grid_logits.shape}')
+ for i in range(hierarchical_octree_depth - dense_octree_depth):
+ curr_octree_depth = dense_octree_depth + i + 1
+ # upsample
+ grid_size = 2**curr_octree_depth
+ normalize_offset = grid_size / 2
+ high_res_occupancy = parallel_zoom(grid_logits, 2).to(dtype)
+
+ band_threshold = 1.0
+ edge_coords = find_candidates_band(grid_logits, band_threshold)
+ expanded_coords = expand_edge_region_fast(edge_coords, grid_size=int(grid_size/2), dtype=dtype).to(dtype)
+ if verbose:
+ print(f'step {i+2} query num: {len(expanded_coords)}')
+ if max_num_expanded_coords > 0 and len(expanded_coords) > max_num_expanded_coords:
+ raise ValueError(f"expanded_coords is too large, {len(expanded_coords)} > {max_num_expanded_coords}")
+ expanded_coords_norm = (expanded_coords - normalize_offset) * (abs(bounds[0]) / normalize_offset)
+
+ all_logits = None
+
+ all_logits = geometric_func(expanded_coords_norm.unsqueeze(0)).to(dtype)
+ all_logits = torch.cat([expanded_coords_norm, all_logits[0]], dim=1)
+ # print("all logits shape = ", all_logits.shape)
+
+ indices = all_logits[..., :3]
+ indices = indices * (normalize_offset / abs(bounds[0])) + normalize_offset
+ indices = indices.type(torch.IntTensor)
+ values = all_logits[:, 3]
+ # breakpoint()
+ high_res_occupancy[indices[:, 0], indices[:, 1], indices[:, 2]] = values
+ grid_logits = high_res_occupancy
+ # torch.cuda.empty_cache()
+
+ if verbose:
+ print("final grids shape = ", grid_logits.shape)
+ vertices, faces, normals, _ = measure.marching_cubes(grid_logits.float().cpu().numpy(), 0, method="lewiner")
+ vertices = vertices / (2**hierarchical_octree_depth) * bbox_size.cpu().numpy() + bbox_min.cpu().numpy()
+ mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
+
+ return mesh_v_f
+
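+# Usage sketch with an analytic field (a sphere of radius 0.5); in the real
+# pipeline `geometric_func` would wrap a shape-VAE decoder instead.
+# sphere_logits = lambda pts: 0.5 - pts.norm(dim=-1, keepdim=True)   # (1, N, 3) -> (1, N, 1)
+# verts, faces = hierarchical_extract_geometry(
+#     sphere_logits, device=torch.device("cuda"), dtype=torch.float16,
+#     dense_octree_depth=7, hierarchical_octree_depth=8,
+# )
+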
+def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
+ """
+ Args:
+ input_tensor: shape [D, D, D], torch.float16
+ alpha: isosurface offset
+ Returns:
+ mask: shape [D, D, D], torch.int32
+ """
+ device = input_tensor.device
+ D = input_tensor.shape[0]
+ signed_val = 0.0
+
+ # add isosurface offset and exclude invalid value
+ val = input_tensor + alpha
+ valid_mask = val > -9000
+
+ # obtain neighbors
+ def get_neighbor(t, shift, axis):
+ if shift == 0:
+ return t.clone()
+
+ pad_dims = [0, 0, 0, 0, 0, 0] # [x_front,x_back,y_front,y_back,z_front,z_back]
+
+ if axis == 0: # x axis
+ pad_idx = 0 if shift > 0 else 1
+ pad_dims[pad_idx] = abs(shift)
+ elif axis == 1: # y axis
+ pad_idx = 2 if shift > 0 else 3
+ pad_dims[pad_idx] = abs(shift)
+ elif axis == 2: # z axis
+ pad_idx = 4 if shift > 0 else 5
+ pad_dims[pad_idx] = abs(shift)
+
+ # Apply padding with replication at boundaries
+ padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')
+
+ # Create dynamic slicing indices
+ slice_dims = [slice(None)] * 3
+ if axis == 0: # x axis
+ if shift > 0:
+ slice_dims[0] = slice(shift, None)
+ else:
+ slice_dims[0] = slice(None, shift)
+ elif axis == 1: # y axis
+ if shift > 0:
+ slice_dims[1] = slice(shift, None)
+ else:
+ slice_dims[1] = slice(None, shift)
+ elif axis == 2: # z axis
+ if shift > 0:
+ slice_dims[2] = slice(shift, None)
+ else:
+ slice_dims[2] = slice(None, shift)
+
+ # Apply slicing and restore dimensions
+ padded = padded.squeeze(0).squeeze(0)
+ sliced = padded[slice_dims]
+ return sliced
+
+ # Get neighbors in all directions
+ left = get_neighbor(val, 1, axis=0) # x axis
+ right = get_neighbor(val, -1, axis=0)
+ back = get_neighbor(val, 1, axis=1) # y axis
+ front = get_neighbor(val, -1, axis=1)
+ down = get_neighbor(val, 1, axis=2) # z axis
+ up = get_neighbor(val, -1, axis=2)
+
+ # Handle invalid boundary values
+ def safe_where(neighbor):
+ return torch.where(neighbor > -9000, neighbor, val)
+
+ left = safe_where(left)
+ right = safe_where(right)
+ back = safe_where(back)
+ front = safe_where(front)
+ down = safe_where(down)
+ up = safe_where(up)
+
+ # Calculate sign consistency
+ sign = torch.sign(val.to(torch.float32))
+ neighbors_sign = torch.stack([
+ torch.sign(left.to(torch.float32)),
+ torch.sign(right.to(torch.float32)),
+ torch.sign(back.to(torch.float32)),
+ torch.sign(front.to(torch.float32)),
+ torch.sign(down.to(torch.float32)),
+ torch.sign(up.to(torch.float32))
+ ], dim=0)
+
+ # Check if all signs are consistent
+ same_sign = torch.all(neighbors_sign == sign, dim=0)
+
+ # Generate final mask
+ mask = (~same_sign).to(torch.int32)
+ return mask * valid_mask.to(torch.int32)
+
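+# Usage sketch: flag voxels whose sign differs from at least one of their six
+# axis neighbours, i.e. voxels crossing the (alpha-shifted) isosurface.
+# near_surface_mask = extract_near_surface_volume_fn(grid_logits, alpha=0.0)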
+
+def generate_dense_grid_points_2(
+ bbox_min: np.ndarray,
+ bbox_max: np.ndarray,
+ octree_resolution: int,
+ indexing: str = "ij",
+):
+ length = bbox_max - bbox_min
+ num_cells = octree_resolution
+
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
+ xyz = np.stack((xs, ys, zs), axis=-1)
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
+
+ return xyz, grid_size, length
+
+@torch.no_grad()
+def flash_extract_geometry(
+ latents: torch.FloatTensor,
+ vae: Callable,
+ bounds: Union[Tuple[float], List[float], float] = 1.01,
+ num_chunks: int = 10000,
+ mc_level: float = 0.0,
+ octree_depth: int = 9,
+ min_resolution: int = 63,
+ mini_grid_num: int = 4,
+ **kwargs,
+):
+ geo_decoder = vae.decoder
+ device = latents.device
+ dtype = latents.dtype
+ # resolution to depth
+ octree_resolution = 2 ** octree_depth
+ resolutions = []
+ if octree_resolution < min_resolution:
+ resolutions.append(octree_resolution)
+ while octree_resolution >= min_resolution:
+ resolutions.append(octree_resolution)
+ octree_resolution = octree_resolution // 2
+ resolutions.reverse()
+ resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
+ for i, resolution in enumerate(resolutions[1:]):
+ resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)
+
+
+ # 1. generate query points
+ if isinstance(bounds, float):
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
+ bbox_min = np.array(bounds[0:3])
+ bbox_max = np.array(bounds[3:6])
+ bbox_size = bbox_max - bbox_min
+
+ xyz_samples, grid_size, length = generate_dense_grid_points_2(
+ bbox_min=bbox_min,
+ bbox_max=bbox_max,
+ octree_resolution=resolutions[0],
+ indexing="ij"
+ )
+
+ dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
+ dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
+
+ grid_size = np.array(grid_size)
+
+ # 2. latents to 3d volume
+ xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
+ batch_size = latents.shape[0]
+ mini_grid_size = xyz_samples.shape[0] // mini_grid_num
+ xyz_samples = xyz_samples.view(
+ mini_grid_num, mini_grid_size,
+ mini_grid_num, mini_grid_size,
+ mini_grid_num, mini_grid_size, 3
+ ).permute(
+ 0, 2, 4, 1, 3, 5, 6
+ ).reshape(
+ -1, mini_grid_size * mini_grid_size * mini_grid_size, 3
+ )
+ batch_logits = []
+ num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
+ for start in range(0, xyz_samples.shape[0], num_batchs):
+ queries = xyz_samples[start: start + num_batchs, :]
+ batch = queries.shape[0]
+ batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
+ # geo_decoder.set_topk(True)
+ geo_decoder.set_topk(False)
+ logits = vae.decode(batch_latents, queries).sample
+ batch_logits.append(logits)
+ grid_logits = torch.cat(batch_logits, dim=0).reshape(
+ mini_grid_num, mini_grid_num, mini_grid_num,
+ mini_grid_size, mini_grid_size,
+ mini_grid_size
+ ).permute(0, 3, 1, 4, 2, 5).contiguous().view(
+ (batch_size, grid_size[0], grid_size[1], grid_size[2])
+ )
+
+ for octree_depth_now in resolutions[1:]:
+ grid_size = np.array([octree_depth_now + 1] * 3)
+ resolution = bbox_size / octree_depth_now
+ next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
+ next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
+ curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
+ curr_points += grid_logits.squeeze(0).abs() < 0.95
+
+ if octree_depth_now == resolutions[-1]:
+ expand_num = 0
+ else:
+ expand_num = 1
+ for i in range(expand_num):
+ curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
+ curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
+ (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
+
+ next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
+ for i in range(2 - expand_num):
+ next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
+ nidx = torch.where(next_index > 0)
+
+ next_points = torch.stack(nidx, dim=1)
+ next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
+ torch.tensor(bbox_min, dtype=torch.float32, device=device))
+
+ query_grid_num = 6
+ min_val = next_points.min(axis=0).values
+ max_val = next_points.max(axis=0).values
+ vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
+ index = torch.floor(vol_queries_index).long()
+ index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
+ index = index.sort()
+ next_points = next_points[index.indices].unsqueeze(0).contiguous()
+ unique_values = torch.unique(index.values, return_counts=True)
+ grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
+ input_grid = [[], []]
+ logits_grid_list = []
+ start_num = 0
+ sum_num = 0
+ for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
+ if sum_num + count < num_chunks or sum_num == 0:
+ sum_num += count
+ input_grid[0].append(grid_index)
+ input_grid[1].append(count)
+ else:
+ # geo_decoder.set_topk(input_grid)
+ geo_decoder.set_topk(False)
+ logits_grid = vae.decode(latents,next_points[:, start_num:start_num + sum_num]).sample
+ start_num = start_num + sum_num
+ logits_grid_list.append(logits_grid)
+ input_grid = [[grid_index], [count]]
+ sum_num = count
+ if sum_num > 0:
+ # geo_decoder.set_topk(input_grid)
+ geo_decoder.set_topk(False)
+ logits_grid = vae.decode(latents,next_points[:, start_num:start_num + sum_num]).sample
+ logits_grid_list.append(logits_grid)
+ logits_grid = torch.cat(logits_grid_list, dim=1)
+ grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
+ next_logits[nidx] = grid_logits
+ grid_logits = next_logits.unsqueeze(0)
+
+ grid_logits[grid_logits == -10000.] = float('nan')
+ torch.cuda.empty_cache()
+ mesh_v_f = []
+ grid_logits = grid_logits[0]
+ try:
+ print("final grids shape = ", grid_logits.shape)
+ dmc = DiffDMC(dtype=torch.float32).to(grid_logits.device)
+ sdf = -grid_logits / octree_resolution
+ sdf = sdf.to(torch.float32).contiguous()
+ vertices, faces = dmc(sdf, deform=None, return_quads=False, normalize=False)
+ vertices = vertices.detach().cpu().numpy()
+ faces = faces.detach().cpu().numpy()[:, ::-1]
+ vertices = vertices / (2 ** octree_depth) * bbox_size + bbox_min
+ mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces))
+ except Exception as e:
+ print(e)
+ torch.cuda.empty_cache()
+ mesh_v_f = (None, None)
+
+ return [mesh_v_f]
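+
+# Usage sketch (assumptions: `vae` is a shape-VAE whose decode(latents, queries)
+# returns an object with a `.sample` logits tensor and whose decoder exposes
+# `set_topk`, and `latents` has shape (1, num_tokens, channels)):
+# (verts, faces), = flash_extract_geometry(latents, vae, octree_depth=9)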
diff --git a/src/utils/metric_utils.py b/src/utils/metric_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4945423334264e044a0f1e87483ab7fed19df4b
--- /dev/null
+++ b/src/utils/metric_utils.py
@@ -0,0 +1,159 @@
+from src.utils.typing_utils import *
+
+import trimesh
+import numpy as np
+from sklearn.neighbors import NearestNeighbors
+
+def sample_from_mesh(
+ mesh: trimesh.Trimesh,
+ num_samples: Optional[int] = 10000,
+):
+ if num_samples is None:
+ return mesh.vertices
+ else:
+ return mesh.sample(num_samples)
+
+def sample_two_meshes(
+ mesh1: trimesh.Trimesh,
+ mesh2: trimesh.Trimesh,
+ num_samples: Optional[int] = 10000,
+):
+ points1 = sample_from_mesh(mesh1, num_samples)
+ points2 = sample_from_mesh(mesh2, num_samples)
+ return points1, points2
+
+def compute_nearest_distance(
+ points1: np.ndarray,
+ points2: np.ndarray,
+ metric: str = 'l2'
+) -> np.ndarray:
+ # Compute nearest neighbor distance from points1 to points2
+ nn = NearestNeighbors(n_neighbors=1, leaf_size=30, algorithm='kd_tree', metric=metric).fit(points2)
+ min_dist = nn.kneighbors(points1)[0]
+ return min_dist
+
+def compute_mutual_nearest_distance(
+ points1: np.ndarray,
+ points2: np.ndarray,
+ metric: str = 'l2'
+) -> np.ndarray:
+ min_1_to_2 = compute_nearest_distance(points1, points2, metric=metric)
+ min_2_to_1 = compute_nearest_distance(points2, points1, metric=metric)
+ return min_1_to_2, min_2_to_1
+
+def compute_mutual_nearest_distance_for_meshes(
+ mesh1: trimesh.Trimesh,
+ mesh2: trimesh.Trimesh,
+ num_samples: Optional[int] = 10000,
+ metric: str = 'l2'
+) -> Tuple[np.ndarray, np.ndarray]:
+ points1 = sample_from_mesh(mesh1, num_samples)
+ points2 = sample_from_mesh(mesh2, num_samples)
+ min_1_to_2, min_2_to_1 = compute_mutual_nearest_distance(points1, points2, metric=metric)
+ return min_1_to_2, min_2_to_1
+
+def compute_chamfer_distance(
+ mesh1: trimesh.Trimesh,
+ mesh2: trimesh.Trimesh,
+ num_samples: int = 10000,
+ metric: str = 'l2'
+):
+ min_1_to_2, min_2_to_1 = compute_mutual_nearest_distance_for_meshes(mesh1, mesh2, num_samples, metric=metric)
+ chamfer_dist = np.mean(min_2_to_1) + np.mean(min_1_to_2)
+ return chamfer_dist
+
+def compute_f_score(
+ mesh1: trimesh.Trimesh,
+ mesh2: trimesh.Trimesh,
+ num_samples: int = 10000,
+ threshold: float = 0.1,
+ metric: str = 'l2'
+):
+ min_1_to_2, min_2_to_1 = compute_mutual_nearest_distance_for_meshes(mesh1, mesh2, num_samples, metric=metric)
+ precision_1 = np.mean((min_1_to_2 < threshold).astype(np.float32))
+ precision_2 = np.mean((min_2_to_1 < threshold).astype(np.float32))
+ fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
+ return fscore
+
+def compute_cd_and_f_score(
+ mesh1: trimesh.Trimesh,
+ mesh2: trimesh.Trimesh,
+ num_samples: Optional[int] = 10000,
+ threshold: float = 0.1,
+ metric: str = 'l2'
+):
+ min_1_to_2, min_2_to_1 = compute_mutual_nearest_distance_for_meshes(mesh1, mesh2, num_samples, metric=metric)
+ chamfer_dist = np.mean(min_2_to_1) + np.mean(min_1_to_2)
+ precision_1 = np.mean((min_1_to_2 < threshold).astype(np.float32))
+ precision_2 = np.mean((min_2_to_1 < threshold).astype(np.float32))
+ fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
+ return chamfer_dist, fscore
+
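+# Usage example (sketch): compare two primitive meshes.
+# mesh_a = trimesh.creation.icosphere(radius=1.0)
+# mesh_b = trimesh.creation.box(extents=(2.0, 2.0, 2.0))
+# cd, f1 = compute_cd_and_f_score(mesh_a, mesh_b, num_samples=10000, threshold=0.1)
+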
+def compute_cd_and_f_score_in_training(
+ gt_surface: np.ndarray,
+ pred_mesh: trimesh.Trimesh,
+ num_samples: int = 204800,
+ threshold: float = 0.1,
+ metric: str = 'l2'
+):
+ gt_points = gt_surface[:, :3]
+    # clamp to the number of available GT points so replace=False sampling is valid
+    num_samples = min(num_samples, gt_points.shape[0])
+ gt_points = gt_points[np.random.choice(gt_points.shape[0], num_samples, replace=False)]
+ pred_points = sample_from_mesh(pred_mesh, num_samples)
+ min_1_to_2, min_2_to_1 = compute_mutual_nearest_distance(gt_points, pred_points, metric=metric)
+ chamfer_dist = np.mean(min_2_to_1) + np.mean(min_1_to_2)
+ precision_1 = np.mean((min_1_to_2 < threshold).astype(np.float32))
+ precision_2 = np.mean((min_2_to_1 < threshold).astype(np.float32))
+ fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
+ return chamfer_dist, fscore
+
+def get_voxel_set(
+ mesh: trimesh.Trimesh,
+ num_grids: int = 64,
+ scale: float = 2.0,
+):
+ if not isinstance(mesh, trimesh.Trimesh):
+ raise ValueError("mesh must be a trimesh.Trimesh object")
+ pitch = scale / num_grids
+    voxel_grids: trimesh.voxel.base.VoxelGrid = mesh.voxelized(pitch=pitch).fill()
+    voxels = set(map(tuple, np.round(voxel_grids.points / pitch).astype(int)))
+ return voxels
+
+def compute_IoU(
+ mesh1: trimesh.Trimesh,
+ mesh2: trimesh.Trimesh,
+ num_grids: int = 64,
+ scale: float = 2.0,
+):
+ if not isinstance(mesh1, trimesh.Trimesh) or not isinstance(mesh2, trimesh.Trimesh):
+ raise ValueError("mesh1 and mesh2 must be trimesh.Trimesh objects")
+ voxels1 = get_voxel_set(mesh1, num_grids, scale)
+ voxels2 = get_voxel_set(mesh2, num_grids, scale)
+ intersection = voxels1 & voxels2
+ union = voxels1 | voxels2
+ iou = len(intersection) / len(union) if len(union) > 0 else 0.0
+ return iou
+
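+# Usage example (sketch): voxel IoU of two unit cubes offset by 0.25 along x.
+# shift = trimesh.transformations.translation_matrix([0.25, 0.0, 0.0])
+# box_a = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
+# box_b = trimesh.creation.box(extents=(1.0, 1.0, 1.0), transform=shift)
+# iou = compute_IoU(box_a, box_b, num_grids=64, scale=2.0)
+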
+def compute_IoU_for_scene(
+ scene: Union[trimesh.Scene, List[trimesh.Trimesh]],
+ num_grids: int = 64,
+ scale: float = 2.0,
+ return_type: Literal["iou", "iou_list"] = "iou",
+):
+ if isinstance(scene, trimesh.Scene):
+ scene = scene.dump()
+ if isinstance(scene, list) and len(scene) > 1 and isinstance(scene[0], trimesh.Trimesh):
+ meshes = scene
+ else:
+ raise ValueError("scene must be a trimesh.Scene object or a list of trimesh.Trimesh objects")
+ ious = []
+ for i in range(len(meshes)):
+ for j in range(i+1, len(meshes)):
+ iou = compute_IoU(meshes[i], meshes[j], num_grids, scale)
+ ious.append(iou)
+ if return_type == "iou":
+ return np.mean(ious)
+ elif return_type == "iou_list":
+ return ious
+ else:
+ raise ValueError("return_type must be 'iou' or 'iou_list'")
\ No newline at end of file
diff --git a/src/utils/render_utils.py b/src/utils/render_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6eba89f28d0a036d1e8902bcadab1f8667126e3
--- /dev/null
+++ b/src/utils/render_utils.py
@@ -0,0 +1,411 @@
+from src.utils.typing_utils import *
+
+import os
+
+# PyOpenGL selects its windowing platform when pyrender (and OpenGL) are first
+# imported, so the EGL backend must be requested before `import pyrender`.
+os.environ['PYOPENGL_PLATFORM'] = 'egl'
+
+import numpy as np
+from PIL import Image
+import trimesh
+from trimesh.transformations import rotation_matrix
+import pyrender
+from diffusers.utils import export_to_video
+from diffusers.utils.loading_utils import load_video
+import torch
+from torchvision.utils import make_grid
+
+def render(
+ scene: pyrender.Scene,
+ renderer: pyrender.Renderer,
+ camera: pyrender.Camera,
+ pose: np.ndarray,
+ light: Optional[pyrender.Light] = None,
+ normalize_depth: bool = False,
+ flags: int = pyrender.constants.RenderFlags.NONE,
+ return_type: Literal['pil', 'ndarray'] = 'pil'
+) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[Image.Image, Image.Image]]:
+ camera_node = scene.add(camera, pose=pose)
+ if light is not None:
+ light_node = scene.add(light, pose=pose)
+ image, depth = renderer.render(
+ scene,
+ flags=flags
+ )
+ scene.remove_node(camera_node)
+ if light is not None:
+ scene.remove_node(light_node)
+ if normalize_depth or return_type == 'pil':
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+ if return_type == 'pil':
+ image = Image.fromarray(image)
+ depth = Image.fromarray(depth.astype(np.uint8))
+ return image, depth
+
+def rotation_matrix_from_vectors(vec1, vec2):
+ a, b = vec1 / np.linalg.norm(vec1), vec2 / np.linalg.norm(vec2)
+ v = np.cross(a, b)
+ c = np.dot(a, b)
+ s = np.linalg.norm(v)
+    if s == 0:
+        if c > 0:
+            return np.eye(3)
+        # Antiparallel vectors: -I is a reflection, not a rotation, so rotate
+        # 180 degrees about an axis perpendicular to `a` instead.
+        perp = np.array([1.0, 0.0, 0.0]) if abs(a[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
+        axis = np.cross(a, perp)
+        axis = axis / np.linalg.norm(axis)
+        return 2.0 * np.outer(axis, axis) - np.eye(3)
+ kmat = np.array([
+ [0, -v[2], v[1]],
+ [v[2], 0, -v[0]],
+ [-v[1], v[0], 0]
+ ])
+ return np.eye(3) + kmat + kmat @ kmat * ((1 - c) / (s ** 2))
+
+def create_circular_camera_positions(
+ num_views: int,
+ radius: float,
+ axis: np.ndarray = np.array([0.0, 1.0, 0.0])
+) -> List[np.ndarray]:
+ # Create a list of positions for a circular camera trajectory
+ # around the given axis with the given radius.
+ positions = []
+ axis = axis / np.linalg.norm(axis)
+ for i in range(num_views):
+ theta = 2 * np.pi * i / num_views
+ position = np.array([
+ np.sin(theta) * radius,
+ 0.0,
+ np.cos(theta) * radius
+ ])
+ if not np.allclose(axis, np.array([0.0, 1.0, 0.0])):
+ R = rotation_matrix_from_vectors(np.array([0.0, 1.0, 0.0]), axis)
+ position = R @ position
+ positions.append(position)
+ return positions
+
+def create_circular_camera_poses(
+ num_views: int,
+ radius: float,
+ axis: np.ndarray = np.array([0.0, 1.0, 0.0])
+) -> List[np.ndarray]:
+ # Create a list of poses for a circular camera trajectory
+ # around the given axis with the given radius.
+ # The camera always looks at the origin.
+ # The up vector is always [0, 1, 0].
+ canonical_pose = np.array([
+ [1.0, 0.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0, 0.0],
+ [0.0, 0.0, 1.0, radius],
+ [0.0, 0.0, 0.0, 1.0]
+ ])
+ poses = []
+ for i in range(num_views):
+ theta = 2 * np.pi * i / num_views
+ R = rotation_matrix(
+ angle=theta,
+ direction=axis,
+ point=[0, 0, 0]
+ )
+ pose = R @ canonical_pose
+ poses.append(pose)
+ return poses
+
+def render_views_around_mesh(
+ mesh: Union[trimesh.Trimesh, trimesh.Scene],
+ num_views: int = 36,
+ radius: float = 3.5,
+ axis: np.ndarray = np.array([0.0, 1.0, 0.0]),
+ image_size: tuple = (512, 512),
+ fov: float = 40.0,
+ light_intensity: Optional[float] = 5.0,
+ znear: float = 0.1,
+ zfar: float = 10.0,
+ normalize_depth: bool = False,
+ flags: int = pyrender.constants.RenderFlags.NONE,
+ return_depth: bool = False,
+ return_type: Literal['pil', 'ndarray'] = 'pil'
+) -> Union[
+ List[Image.Image],
+ List[np.ndarray],
+ Tuple[List[Image.Image], List[Image.Image]],
+ Tuple[List[np.ndarray], List[np.ndarray]]
+ ]:
+
+ if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
+ raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
+ if isinstance(mesh, trimesh.Trimesh):
+ mesh = trimesh.Scene(mesh)
+
+ scene = pyrender.Scene.from_trimesh_scene(mesh)
+ light = pyrender.DirectionalLight(
+ color=np.ones(3),
+ intensity=light_intensity
+ ) if light_intensity is not None else None
+ camera = pyrender.PerspectiveCamera(
+ yfov=np.deg2rad(fov),
+ aspectRatio=image_size[0]/image_size[1],
+ znear=znear,
+ zfar=zfar
+ )
+ renderer = pyrender.OffscreenRenderer(*image_size)
+
+ camera_poses = create_circular_camera_poses(
+ num_views,
+ radius,
+ axis = axis
+ )
+
+ images, depths = [], []
+ for pose in camera_poses:
+ image, depth = render(
+ scene, renderer, camera, pose, light,
+ normalize_depth=normalize_depth,
+ flags=flags,
+ return_type=return_type
+ )
+ images.append(image)
+ depths.append(depth)
+
+ renderer.delete()
+
+ if return_depth:
+ return images, depths
+ return images
+
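+# Usage sketch: render a 36-frame turntable of a mesh as PIL images
+# ("model.glb" is a placeholder path).
+# mesh = trimesh.load("model.glb")
+# frames = render_views_around_mesh(mesh, num_views=36, radius=3.5)
+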
+def render_normal_views_around_mesh(
+ mesh: Union[trimesh.Trimesh, trimesh.Scene],
+ num_views: int = 36,
+ radius: float = 3.5,
+ axis: np.ndarray = np.array([0.0, 1.0, 0.0]),
+ image_size: tuple = (512, 512),
+ fov: float = 40.0,
+ light_intensity: Optional[float] = 5.0,
+ znear: float = 0.1,
+ zfar: float = 10.0,
+ normalize_depth: bool = False,
+ flags: int = pyrender.constants.RenderFlags.NONE,
+ return_depth: bool = False,
+ return_type: Literal['pil', 'ndarray'] = 'pil'
+) -> Union[
+ List[Image.Image],
+ List[np.ndarray],
+ Tuple[List[Image.Image], List[Image.Image]],
+ Tuple[List[np.ndarray], List[np.ndarray]]
+ ]:
+
+ if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
+ raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
+ if isinstance(mesh, trimesh.Scene):
+ mesh = mesh.to_geometry()
+ normals = mesh.vertex_normals
+ colors = ((normals + 1.0) / 2.0 * 255).astype(np.uint8)
+ mesh.visual = trimesh.visual.ColorVisuals(
+ mesh=mesh,
+ vertex_colors=colors
+ )
+ mesh = trimesh.Scene(mesh)
+ return render_views_around_mesh(
+ mesh, num_views, radius, axis,
+ image_size, fov, light_intensity, znear, zfar,
+ normalize_depth, flags,
+ return_depth, return_type
+ )
+
+def create_camera_pose_on_sphere(
+ azimuth: float = 0.0, # in degrees
+ elevation: float = 0.0, # in degrees
+ radius: float = 3.5,
+) -> np.ndarray:
+ # Create a camera pose for a given azimuth and elevation
+ # with the given radius.
+ # The camera always looks at the origin.
+ # The up vector is always [0, 1, 0].
+ canonical_pose = np.array([
+ [1.0, 0.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0, 0.0],
+ [0.0, 0.0, 1.0, radius],
+ [0.0, 0.0, 0.0, 1.0]
+ ])
+ azimuth = np.deg2rad(azimuth)
+ elevation = np.deg2rad(elevation)
+ position = np.array([
+ np.cos(elevation) * np.sin(azimuth),
+ np.sin(elevation),
+ np.cos(elevation) * np.cos(azimuth),
+ ])
+ R = np.eye(4)
+ R[:3, :3] = rotation_matrix_from_vectors(
+ np.array([0.0, 0.0, 1.0]),
+ position
+ )
+ pose = R @ canonical_pose
+ return pose
+
+def render_single_view(
+ mesh: Union[trimesh.Trimesh, trimesh.Scene],
+ azimuth: float = 0.0, # in degrees
+ elevation: float = 0.0, # in degrees
+ radius: float = 3.5,
+ image_size: tuple = (512, 512),
+ fov: float = 40.0,
+ light_intensity: Optional[float] = 5.0,
+ num_env_lights: int = 0,
+ znear: float = 0.1,
+ zfar: float = 10.0,
+ normalize_depth: bool = False,
+ flags: int = pyrender.constants.RenderFlags.NONE,
+ return_depth: bool = False,
+ return_type: Literal['pil', 'ndarray'] = 'pil'
+) -> Union[
+ Image.Image,
+ np.ndarray,
+ Tuple[Image.Image, Image.Image],
+ Tuple[np.ndarray, np.ndarray]
+ ]:
+
+ if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
+ raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
+ if isinstance(mesh, trimesh.Trimesh):
+ mesh = trimesh.Scene(mesh)
+
+ scene = pyrender.Scene.from_trimesh_scene(mesh)
+ light = pyrender.DirectionalLight(
+ color=np.ones(3),
+ intensity=light_intensity
+ ) if light_intensity is not None else None
+ camera = pyrender.PerspectiveCamera(
+ yfov=np.deg2rad(fov),
+ aspectRatio=image_size[0]/image_size[1],
+ znear=znear,
+ zfar=zfar
+ )
+ renderer = pyrender.OffscreenRenderer(*image_size)
+
+ camera_pose = create_camera_pose_on_sphere(
+ azimuth,
+ elevation,
+ radius
+ )
+
+ if num_env_lights > 0:
+ env_light_poses = create_circular_camera_poses(
+ num_env_lights,
+ radius,
+ axis = np.array([0.0, 1.0, 0.0])
+ )
+ for pose in env_light_poses:
+ scene.add(pyrender.DirectionalLight(
+ color=np.ones(3),
+ intensity=light_intensity
+ ), pose=pose)
+ # set light to None
+ light = None
+
+ image, depth = render(
+ scene, renderer, camera, camera_pose, light,
+ normalize_depth=normalize_depth,
+ flags=flags,
+ return_type=return_type
+ )
+ renderer.delete()
+
+ if return_depth:
+ return image, depth
+ return image
+
+def render_normal_single_view(
+ mesh: Union[trimesh.Trimesh, trimesh.Scene],
+ azimuth: float = 0.0, # in degrees
+ elevation: float = 0.0, # in degrees
+ radius: float = 3.5,
+ image_size: tuple = (512, 512),
+ fov: float = 40.0,
+ light_intensity: Optional[float] = 5.0,
+ znear: float = 0.1,
+ zfar: float = 10.0,
+ normalize_depth: bool = False,
+ flags: int = pyrender.constants.RenderFlags.NONE,
+ return_depth: bool = False,
+ return_type: Literal['pil', 'ndarray'] = 'pil'
+) -> Union[
+ Image.Image,
+ np.ndarray,
+ Tuple[Image.Image, Image.Image],
+ Tuple[np.ndarray, np.ndarray]
+ ]:
+
+ if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
+ raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
+ if isinstance(mesh, trimesh.Scene):
+ mesh = mesh.to_geometry()
+ normals = mesh.vertex_normals
+ colors = ((normals + 1.0) / 2.0 * 255).astype(np.uint8)
+ mesh.visual = trimesh.visual.ColorVisuals(
+ mesh=mesh,
+ vertex_colors=colors
+ )
+ mesh = trimesh.Scene(mesh)
+ return render_single_view(
+ mesh, azimuth, elevation, radius,
+ image_size, fov, light_intensity, znear, zfar,
+ normalize_depth, flags,
+ return_depth, return_type
+ )
+
+def export_renderings(
+ images: List[Image.Image],
+ export_path: str,
+ fps: int = 36,
+ loop: int = 0
+):
+ export_type = export_path.split('.')[-1]
+ if export_type == 'mp4':
+ export_to_video(
+ images,
+ export_path,
+ fps=fps,
+ )
+ elif export_type == 'gif':
+ duration = 1000 / fps
+ images[0].save(
+ export_path,
+ save_all=True,
+ append_images=images[1:],
+ duration=duration,
+ loop=loop
+ )
+ else:
+ raise ValueError(f'Unknown export type: {export_type}')
+
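+# Usage sketch (continuing the turntable example above; "turntable.gif" is a
+# placeholder output path):
+# export_renderings(frames, "turntable.gif", fps=18)
+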
+def make_grid_for_images_or_videos(
+ images_or_videos: Union[List[Image.Image], List[List[Image.Image]]],
+ nrow: int = 4,
+ padding: int = 0,
+ pad_value: int = 0,
+ image_size: tuple = (512, 512),
+ return_type: Literal['pil', 'ndarray'] = 'pil'
+) -> Union[Image.Image, List[Image.Image], np.ndarray]:
+ if isinstance(images_or_videos[0], Image.Image):
+ images = [np.array(image.resize(image_size).convert('RGB')) for image in images_or_videos]
+ images = np.stack(images, axis=0).transpose(0, 3, 1, 2) # [N, C, H, W]
+ images = torch.from_numpy(images)
+ image_grid = make_grid(
+ images,
+ nrow=nrow,
+ padding=padding,
+ pad_value=pad_value,
+ normalize=False
+ ) # [C, H', W']
+ image_grid = image_grid.cpu().numpy()
+ if return_type == 'pil':
+ image_grid = Image.fromarray(image_grid.transpose(1, 2, 0))
+ return image_grid
+ elif isinstance(images_or_videos[0], list) and isinstance(images_or_videos[0][0], Image.Image):
+ image_grids = []
+ for i in range(len(images_or_videos[0])):
+ images = [video[i] for video in images_or_videos]
+ image_grid = make_grid_for_images_or_videos(
+ images,
+ nrow=nrow,
+ padding=padding,
+ return_type=return_type
+ )
+ image_grids.append(image_grid)
+ if return_type == 'ndarray':
+ image_grids = np.stack(image_grids, axis=0)
+ return image_grids
+ else:
+ raise ValueError(f'Unknown input type: {type(images_or_videos[0])}')
\ No newline at end of file
diff --git a/src/utils/smoothing.py b/src/utils/smoothing.py
new file mode 100644
index 0000000000000000000000000000000000000000..046098f63bc1c6beadebf08d8d07de74a05fbb08
--- /dev/null
+++ b/src/utils/smoothing.py
@@ -0,0 +1,643 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2012-2015, P. M. Neila
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""
+Utilities for smoothing the occ/sdf grids.
+"""
+
+import logging
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy import ndimage as ndi
+from scipy import sparse
+
+__all__ = [
+ "smooth",
+ "smooth_constrained",
+ "smooth_gaussian",
+ "signed_distance_function",
+ "smooth_gpu",
+ "smooth_constrained_gpu",
+ "smooth_gaussian_gpu",
+ "signed_distance_function_gpu",
+]
+
+
+def _build_variable_indices(band: np.ndarray) -> np.ndarray:
+ num_variables = np.count_nonzero(band)
+ variable_indices = np.full(band.shape, -1, dtype=np.int_)
+ variable_indices[band] = np.arange(num_variables)
+ return variable_indices
+
+
+def _buildq3d(variable_indices: np.ndarray):
+    """
+    Build the smoothness matrix filterq.T @ filterq, where filterq stacks a
+    1-D second-difference operator along each axis for every in-band variable.
+    """
+
+ num_variables = variable_indices.max() + 1
+ filterq = sparse.lil_matrix((3 * num_variables, num_variables))
+
+ # Pad variable_indices to simplify out-of-bounds accesses
+ variable_indices = np.pad(
+ variable_indices, [(0, 1), (0, 1), (0, 1)], mode="constant", constant_values=-1
+ )
+
+ coords = np.nonzero(variable_indices >= 0)
+ for count, (i, j, k) in enumerate(zip(*coords)):
+
+ assert variable_indices[i, j, k] == count
+
+ filterq[3 * count, count] = -2
+ neighbor = variable_indices[i - 1, j, k]
+ if neighbor >= 0:
+ filterq[3 * count, neighbor] = 1
+ else:
+ filterq[3 * count, count] += 1
+
+ neighbor = variable_indices[i + 1, j, k]
+ if neighbor >= 0:
+ filterq[3 * count, neighbor] = 1
+ else:
+ filterq[3 * count, count] += 1
+
+ filterq[3 * count + 1, count] = -2
+ neighbor = variable_indices[i, j - 1, k]
+ if neighbor >= 0:
+ filterq[3 * count + 1, neighbor] = 1
+ else:
+ filterq[3 * count + 1, count] += 1
+
+ neighbor = variable_indices[i, j + 1, k]
+ if neighbor >= 0:
+ filterq[3 * count + 1, neighbor] = 1
+ else:
+ filterq[3 * count + 1, count] += 1
+
+ filterq[3 * count + 2, count] = -2
+ neighbor = variable_indices[i, j, k - 1]
+ if neighbor >= 0:
+ filterq[3 * count + 2, neighbor] = 1
+ else:
+ filterq[3 * count + 2, count] += 1
+
+ neighbor = variable_indices[i, j, k + 1]
+ if neighbor >= 0:
+ filterq[3 * count + 2, neighbor] = 1
+ else:
+ filterq[3 * count + 2, count] += 1
+
+ filterq = filterq.tocsr()
+ return filterq.T.dot(filterq)
+
+
+def _buildq3d_gpu(variable_indices: torch.Tensor, chunk_size=10000):
+ """
+ Builds the filterq matrix for the given variables on GPU, using chunking to reduce memory usage.
+ """
+ device = variable_indices.device
+ num_variables = variable_indices.max().item() + 1
+
+ # Pad variable_indices to simplify out-of-bounds accesses
+ variable_indices = torch.nn.functional.pad(
+ variable_indices, (0, 1, 0, 1, 0, 1), mode="constant", value=-1
+ )
+
+ coords = torch.nonzero(variable_indices >= 0)
+ i, j, k = coords[:, 0], coords[:, 1], coords[:, 2]
+
+ # Function to process a chunk of data
+ def process_chunk(start, end):
+ row_indices = []
+ col_indices = []
+ values = []
+
+ for axis in range(3):
+ row_indices.append(3 * torch.arange(start, end, device=device) + axis)
+ col_indices.append(
+ variable_indices[i[start:end], j[start:end], k[start:end]]
+ )
+ values.append(torch.full((end - start,), -2, device=device))
+
+ for offset in [-1, 1]:
+ if axis == 0:
+ neighbor = variable_indices[
+ i[start:end] + offset, j[start:end], k[start:end]
+ ]
+ elif axis == 1:
+ neighbor = variable_indices[
+ i[start:end], j[start:end] + offset, k[start:end]
+ ]
+ else:
+ neighbor = variable_indices[
+ i[start:end], j[start:end], k[start:end] + offset
+ ]
+
+ mask = neighbor >= 0
+ row_indices.append(
+ 3 * torch.arange(start, end, device=device)[mask] + axis
+ )
+ col_indices.append(neighbor[mask])
+ values.append(torch.ones(mask.sum(), device=device))
+
+ # Add 1 to the diagonal for out-of-bounds neighbors
+ row_indices.append(
+ 3 * torch.arange(start, end, device=device)[~mask] + axis
+ )
+ col_indices.append(
+ variable_indices[i[start:end], j[start:end], k[start:end]][~mask]
+ )
+ values.append(torch.ones((~mask).sum(), device=device))
+
+ return torch.cat(row_indices), torch.cat(col_indices), torch.cat(values)
+
+ # Process data in chunks
+ all_row_indices = []
+ all_col_indices = []
+ all_values = []
+
+ for start in range(0, coords.shape[0], chunk_size):
+ end = min(start + chunk_size, coords.shape[0])
+ row_indices, col_indices, values = process_chunk(start, end)
+ all_row_indices.append(row_indices)
+ all_col_indices.append(col_indices)
+ all_values.append(values)
+
+ # Concatenate all chunks
+ row_indices = torch.cat(all_row_indices)
+ col_indices = torch.cat(all_col_indices)
+ values = torch.cat(all_values)
+
+ # Create sparse tensor
+ indices = torch.stack([row_indices, col_indices])
+ filterq = torch.sparse_coo_tensor(
+ indices, values, (3 * num_variables, num_variables)
+ )
+
+ # Compute filterq.T @ filterq
+ return torch.sparse.mm(filterq.t(), filterq)
+
+
+# Usage example:
+# variable_indices = torch.tensor(...).cuda() # Your input tensor on GPU
+# result = _buildq3d_gpu(variable_indices)
+
+
+def _buildq2d(variable_indices: np.ndarray):
+ """
+ Builds the filterq matrix for the given variables.
+
+ Version for 2 dimensions.
+ """
+
+ num_variables = variable_indices.max() + 1
+    # two second-difference rows (one per in-plane axis) for each in-band variable
+    filterq = sparse.lil_matrix((2 * num_variables, num_variables))
+
+ # Pad variable_indices to simplify out-of-bounds accesses
+ variable_indices = np.pad(
+ variable_indices, [(0, 1), (0, 1)], mode="constant", constant_values=-1
+ )
+
+ coords = np.nonzero(variable_indices >= 0)
+ for count, (i, j) in enumerate(zip(*coords)):
+ assert variable_indices[i, j] == count
+
+ filterq[2 * count, count] = -2
+ neighbor = variable_indices[i - 1, j]
+ if neighbor >= 0:
+ filterq[2 * count, neighbor] = 1
+ else:
+ filterq[2 * count, count] += 1
+
+ neighbor = variable_indices[i + 1, j]
+ if neighbor >= 0:
+ filterq[2 * count, neighbor] = 1
+ else:
+ filterq[2 * count, count] += 1
+
+ filterq[2 * count + 1, count] = -2
+ neighbor = variable_indices[i, j - 1]
+ if neighbor >= 0:
+ filterq[2 * count + 1, neighbor] = 1
+ else:
+ filterq[2 * count + 1, count] += 1
+
+ neighbor = variable_indices[i, j + 1]
+ if neighbor >= 0:
+ filterq[2 * count + 1, neighbor] = 1
+ else:
+ filterq[2 * count + 1, count] += 1
+
+ filterq = filterq.tocsr()
+ return filterq.T.dot(filterq)
+
+
+def _jacobi(
+ filterq,
+ x0: np.ndarray,
+ lower_bound: np.ndarray,
+ upper_bound: np.ndarray,
+ max_iters: int = 10,
+ rel_tol: float = 1e-6,
+ weight: float = 0.5,
+):
+ """Jacobi method with constraints."""
+
+ jacobi_r = sparse.lil_matrix(filterq)
+ shp = jacobi_r.shape
+ jacobi_d = 1.0 / filterq.diagonal()
+ jacobi_r.setdiag((0,) * shp[0])
+ jacobi_r = jacobi_r.tocsr()
+
+ x = x0
+
+ # We check the stopping criterion each 10 iterations
+ check_each = 10
+ cum_rel_tol = 1 - (1 - rel_tol) ** check_each
+
+ energy_now = np.dot(x, filterq.dot(x)) / 2
+ logging.info("Energy at iter %d: %.6g", 0, energy_now)
+ for i in range(max_iters):
+
+ x_1 = -jacobi_d * jacobi_r.dot(x)
+ x = weight * x_1 + (1 - weight) * x
+
+ # Constraints.
+ x = np.maximum(x, lower_bound)
+ x = np.minimum(x, upper_bound)
+
+ # Stopping criterion
+ if (i + 1) % check_each == 0:
+ # Update energy
+ energy_before = energy_now
+ energy_now = np.dot(x, filterq.dot(x)) / 2
+
+ logging.info("Energy at iter %d: %.6g", i + 1, energy_now)
+
+ # Check stopping criterion
+ cum_rel_improvement = (energy_before - energy_now) / energy_before
+ if cum_rel_improvement < cum_rel_tol:
+ break
+
+ return x
+
+
+def signed_distance_function(
+ levelset: np.ndarray, band_radius: int
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Return the distance to the 0.5 levelset of a function, the mask of the
+ border (i.e., the nearest cells to the 0.5 level-set) and the mask of the
+ band (i.e., the cells of the function whose distance to the 0.5 level-set
+    is less than or equal to `band_radius`).
+ """
+
+ binary_array = np.where(levelset > 0, True, False)
+
+ # Compute the band and the border.
+ dist_func = ndi.distance_transform_edt
+ distance = np.where(
+ binary_array, dist_func(binary_array) - 0.5, -dist_func(~binary_array) + 0.5
+ )
+ border = np.abs(distance) < 1
+ band = np.abs(distance) <= band_radius
+
+ return distance, border, band
+
+
+def signed_distance_function_iso0(
+ levelset: np.ndarray, band_radius: int
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Return the distance to the 0 levelset of a function, the mask of the
+ border (i.e., the nearest cells to the 0 level-set) and the mask of the
+ band (i.e., the cells of the function whose distance to the 0 level-set
+    is less than or equal to `band_radius`).
+ """
+
+ binary_array = levelset > 0
+
+ # Compute the band and the border.
+ dist_func = ndi.distance_transform_edt
+ distance = np.where(
+ binary_array, dist_func(binary_array), -dist_func(~binary_array)
+ )
+ border = np.zeros_like(levelset, dtype=bool)
+ border[:-1, :, :] |= levelset[:-1, :, :] * levelset[1:, :, :] <= 0
+ border[:, :-1, :] |= levelset[:, :-1, :] * levelset[:, 1:, :] <= 0
+ border[:, :, :-1] |= levelset[:, :, :-1] * levelset[:, :, 1:] <= 0
+ band = np.abs(distance) <= band_radius
+
+ return distance, border, band
+
+
+def signed_distance_function_gpu(levelset: torch.Tensor, band_radius: int):
+ binary_array = (levelset > 0).float()
+
+ # Compute distance transform
+ dist_pos = (
+ F.max_pool3d(
+ -binary_array.unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1
+ )
+ .squeeze(0)
+ .squeeze(0)
+ + binary_array
+ )
+ dist_neg = F.max_pool3d(
+ (binary_array - 1).unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1
+ ).squeeze(0).squeeze(0) + (1 - binary_array)
+
+ distance = torch.where(binary_array > 0, dist_pos - 0.5, -dist_neg + 0.5)
+
+ # breakpoint()
+
+ # Use levelset as distance directly
+ # distance = levelset
+ # print(distance.shape)
+ # Compute border and band
+ border = torch.abs(distance) < 1
+ band = torch.abs(distance) <= band_radius
+
+ return distance, border, band
+
+
+def smooth_constrained(
+ binary_array: np.ndarray,
+ band_radius: int = 4,
+ max_iters: int = 250,
+ rel_tol: float = 1e-6,
+) -> np.ndarray:
+ """
+ Implementation of the smoothing method from
+
+ "Surface Extraction from Binary Volumes with Higher-Order Smoothness"
+ Victor Lempitsky, CVPR10
+ """
+
+ # # Compute the distance map, the border and the band.
+ logging.info("Computing distance transform...")
+ # distance, _, band = signed_distance_function(binary_array, band_radius)
+ binary_array_gpu = torch.from_numpy(binary_array).cuda()
+ distance, _, band = signed_distance_function_gpu(binary_array_gpu, band_radius)
+ distance = distance.cpu().numpy()
+ band = band.cpu().numpy()
+
+ variable_indices = _build_variable_indices(band)
+
+ # Compute filterq.
+ logging.info("Building matrix filterq...")
+ if binary_array.ndim == 3:
+ filterq = _buildq3d(variable_indices)
+ # variable_indices_gpu = torch.from_numpy(variable_indices).cuda()
+ # filterq_gpu = _buildq3d_gpu(variable_indices_gpu)
+ # filterq = filterq_gpu.cpu().numpy()
+ elif binary_array.ndim == 2:
+ filterq = _buildq2d(variable_indices)
+ else:
+ raise ValueError("binary_array.ndim not in [2, 3]")
+
+ # Initialize the variables.
+ res = np.asarray(distance, dtype=np.double)
+ x = res[band]
+ upper_bound = np.where(x < 0, x, np.inf)
+ lower_bound = np.where(x > 0, x, -np.inf)
+
+ upper_bound[np.abs(upper_bound) < 1] = 0
+ lower_bound[np.abs(lower_bound) < 1] = 0
+
+ # Solve.
+ logging.info("Minimizing energy...")
+ x = _jacobi(
+ filterq=filterq,
+ x0=x,
+ lower_bound=lower_bound,
+ upper_bound=upper_bound,
+ max_iters=max_iters,
+ rel_tol=rel_tol,
+ )
+
+ res[band] = x
+ return res
+
+
+def total_variation_denoising(x, weight=0.1, num_iterations=5, eps=1e-8):
+ diff_x = torch.diff(x, dim=0, prepend=x[:1])
+ diff_y = torch.diff(x, dim=1, prepend=x[:, :1])
+ diff_z = torch.diff(x, dim=2, prepend=x[:, :, :1])
+
+ norm = torch.sqrt(diff_x**2 + diff_y**2 + diff_z**2 + eps)
+
+ div_x = torch.diff(diff_x / norm, dim=0, append=diff_x[-1:] / norm[-1:])
+ div_y = torch.diff(diff_y / norm, dim=1, append=diff_y[:, -1:] / norm[:, -1:])
+ div_z = torch.diff(diff_z / norm, dim=2, append=diff_z[:, :, -1:] / norm[:, :, -1:])
+
+ return x - weight * (div_x + div_y + div_z)
+
+
+def smooth_constrained_gpu(
+ binary_array: torch.Tensor,
+ band_radius: int = 4,
+ max_iters: int = 250,
+ rel_tol: float = 1e-4,
+):
+ distance, _, band = signed_distance_function_gpu(binary_array, band_radius)
+
+ # Initialize variables
+ x = distance[band]
+ upper_bound = torch.where(x < 0, x, torch.tensor(float("inf"), device=x.device))
+ lower_bound = torch.where(x > 0, x, torch.tensor(float("-inf"), device=x.device))
+
+ upper_bound[torch.abs(upper_bound) < 1] = 0
+ lower_bound[torch.abs(lower_bound) < 1] = 0
+
+ # Define the 3D Laplacian kernel
+ laplacian_kernel = torch.tensor(
+ [
+ [
+ [
+ [[0, 1, 0], [1, -6, 1], [0, 1, 0]],
+ [[1, 0, 1], [0, 0, 0], [1, 0, 1]],
+ [[0, 1, 0], [1, 0, 1], [0, 1, 0]],
+ ]
+ ]
+ ],
+ device=x.device,
+ ).float()
+
+ laplacian_kernel = laplacian_kernel / laplacian_kernel.abs().sum()
+
+ # Simplified Jacobi iteration
+ for i in range(max_iters):
+ # Reshape x to 5D tensor (batch, channel, depth, height, width)
+ x_5d = x.view(1, 1, *band.shape)
+ x_3d = x.view(*band.shape)
+
+ # Apply 3D convolution
+ laplacian = F.conv3d(x_5d, laplacian_kernel, padding=1)
+
+ # Reshape back to original dimensions
+ laplacian = laplacian.view(x.shape)
+
+ # Use a small relaxation factor to improve stability
+ relaxation_factor = 0.1
+ tv_weight = 0.1
+ # x_new = x + relaxation_factor * laplacian
+ x_new = total_variation_denoising(x_3d, weight=tv_weight)
+ # Print laplacian min and max
+ # print(f"Laplacian min: {laplacian.min().item():.4f}, max: {laplacian.max().item():.4f}")
+
+ # Apply constraints
+ # Reshape x_new to match the dimensions of lower_bound and upper_bound
+ x_new = x_new.view(x.shape)
+ x_new = torch.clamp(x_new, min=lower_bound, max=upper_bound)
+
+ # Check for convergence
+ diff_norm = torch.norm(x_new - x)
+ x_norm = torch.norm(x)
+
+ if x_norm > 1e-8: # Avoid division by very small numbers
+ relative_change = diff_norm / x_norm
+ if relative_change < rel_tol:
+ break
+ elif diff_norm < rel_tol: # If x_norm is very small, check absolute change
+ break
+
+ x = x_new
+
+ # Check for NaN and break if found, also check for inf
+ if torch.isnan(x).any() or torch.isinf(x).any():
+ print(f"NaN or Inf detected at iteration {i}")
+ break
+
+ result = distance.clone()
+ result[band] = x
+ return result
+
+
+def smooth_gaussian(binary_array: np.ndarray, sigma: float = 3) -> np.ndarray:
+    vol = binary_array.astype(np.float64) - 0.5
+ return ndi.gaussian_filter(vol, sigma=sigma)
+
+
+def smooth_gaussian_gpu(binary_array: torch.Tensor, sigma: float = 3):
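+    # Note: despite the name, this GPU variant applies a box (mean) filter of
+    # size 2 * sigma + 1 as a cheap approximation of the CPU Gaussian above.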
+ # vol = binary_array.float()
+ vol = binary_array
+ kernel_size = int(2 * sigma + 1)
+ kernel = torch.ones(
+ 1,
+ 1,
+ kernel_size,
+ kernel_size,
+ kernel_size,
+ device=binary_array.device,
+ dtype=vol.dtype,
+ ) / (kernel_size**3)
+ return F.conv3d(
+ vol.unsqueeze(0).unsqueeze(0), kernel, padding=kernel_size // 2
+ ).squeeze()
+
+
+def smooth(binary_array: np.ndarray, method: str = "auto", **kwargs) -> np.ndarray:
+ """
+ Smooths the 0.5 level-set of a binary array. Returns a floating-point
+ array with a smoothed version of the original level-set in the 0 isovalue.
+
+ This function can apply two different methods:
+
+    - A constrained smoothing method which preserves details and fine
+      structures, but is slow and requires a large amount of memory. This
+      method is recommended when the input array is small (at most 512**3
+      voxels, the threshold used by the 'auto' mode).
+ - A Gaussian filter applied over the binary array. This method is fast, but
+ not very precise, as it can destroy fine details. It is only recommended
+ when the input array is large and the 0.5 level-set does not contain
+ thin structures.
+
+ Parameters
+ ----------
+ binary_array : ndarray
+ Input binary array with the 0.5 level-set to smooth.
+ method : str, one of ['auto', 'gaussian', 'constrained']
+ Smoothing method. If 'auto' is given, the method will be automatically
+ chosen based on the size of `binary_array`.
+
+ Parameters for 'gaussian'
+ -------------------------
+ sigma : float
+ Size of the Gaussian filter (default 3).
+
+ Parameters for 'constrained'
+ ----------------------------
+ max_iters : positive integer
+ Number of iterations of the constrained optimization method
+ (default 250).
+    rel_tol : float
+        Relative tolerance used as the stopping criterion (default 1e-4).
+
+ Output
+ ------
+ res : ndarray
+ Floating-point array with a smoothed 0 level-set.
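+
+    Examples
+    --------
+    A minimal, illustrative sketch (the array shape and the method choice are
+    arbitrary)::
+
+        binary = np.zeros((64, 64, 64), dtype=bool)
+        binary[16:48, 16:48, 16:48] = True
+        smoothed = smooth(binary, method="constrained", max_iters=100)
+        # the smoothed surface is the 0 level-set of `smoothed`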
+ """
+
+ binary_array = np.asarray(binary_array)
+
+ if method == "auto":
+ if binary_array.size > 512**3:
+ method = "gaussian"
+ else:
+ method = "constrained"
+
+ if method == "gaussian":
+ return smooth_gaussian(binary_array, **kwargs)
+
+ if method == "constrained":
+ return smooth_constrained(binary_array, **kwargs)
+
+ raise ValueError("Unknown method '{}'".format(method))
+
+
+def smooth_gpu(binary_array: torch.Tensor, method: str = "auto", **kwargs):
+ if method == "auto":
+ method = "gaussian" if binary_array.numel() > 512**3 else "constrained"
+
+ if method == "gaussian":
+ return smooth_gaussian_gpu(binary_array, **kwargs)
+ elif method == "constrained":
+ return smooth_constrained_gpu(binary_array, **kwargs)
+ else:
+ raise ValueError(f"Unknown method '{method}'")
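+
+
+# Illustrative usage sketch (not exercised anywhere in this module): the device
+# and shape below are arbitrary, and method="auto" picks the constrained path
+# for volumes of at most 512**3 voxels.
+#
+#     occupancy = (torch.rand(64, 64, 64, device="cuda") > 0.5).float()
+#     smoothed = smooth_gpu(occupancy, method="auto")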
diff --git a/src/utils/train_utils.py b/src/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cfaa6ca0ddbed3b023de8d6cbacc819b8b98a8d
--- /dev/null
+++ b/src/utils/train_utils.py
@@ -0,0 +1,184 @@
+from src.utils.typing_utils import *
+
+import os
+import torch
+from omegaconf import OmegaConf
+
+from torch import optim
+from torch.optim import lr_scheduler
+from diffusers.training_utils import *
+from diffusers.optimization import get_scheduler
+
+# https://github.com/huggingface/diffusers/pull/9812: do not force
+# `self.use_ema_warmup` to True when a `torch.nn.Module` is passed.
+class MyEMAModel(EMAModel):
+    """
+    Exponential Moving Average of model weights.
+    """
+
+ def __init__(
+ self,
+ parameters: Iterable[torch.nn.Parameter],
+ decay: float = 0.9999,
+ min_decay: float = 0.0,
+ update_after_step: int = 0,
+ use_ema_warmup: bool = False,
+ inv_gamma: Union[float, int] = 1.0,
+ power: Union[float, int] = 2 / 3,
+ foreach: bool = False,
+ model_cls: Optional[Any] = None,
+        model_config: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ):
+ """
+ Args:
+ parameters (Iterable[torch.nn.Parameter]): The parameters to track.
+ decay (float): The decay factor for the exponential moving average.
+ min_decay (float): The minimum decay factor for the exponential moving average.
+ update_after_step (int): The number of steps to wait before starting to update the EMA weights.
+ use_ema_warmup (bool): Whether to use EMA warmup.
+ inv_gamma (float):
+ Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
+ power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
+ foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
+ device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
+ weights will be stored on CPU.
+
+ @crowsonkb's notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+ """
+        """
+
+ if isinstance(parameters, torch.nn.Module):
+ deprecation_message = (
+ "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
+ "Please pass the parameters of the module instead."
+ )
+ deprecate(
+ "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ parameters = parameters.parameters()
+
+ # # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
+ # use_ema_warmup = True
+
+ if kwargs.get("max_value", None) is not None:
+ deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
+ deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
+ decay = kwargs["max_value"]
+
+ if kwargs.get("min_value", None) is not None:
+ deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
+ deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
+ min_decay = kwargs["min_value"]
+
+ parameters = list(parameters)
+ self.shadow_params = [p.clone().detach() for p in parameters]
+
+ if kwargs.get("device", None) is not None:
+ deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
+ deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
+ self.to(device=kwargs["device"])
+
+ self.temp_stored_params = None
+
+ self.decay = decay
+ self.min_decay = min_decay
+ self.update_after_step = update_after_step
+ self.use_ema_warmup = use_ema_warmup
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.optimization_step = 0
+ self.cur_decay_value = None # set in `step()`
+ self.foreach = foreach
+
+ self.model_cls = model_cls
+ self.model_config = model_config
+
+ def get_decay(self, optimization_step: int) -> float:
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ step = max(0, optimization_step - self.update_after_step - 1)
+
+ if step <= 0:
+ return 0.0
+
+ if self.use_ema_warmup:
+ cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
+ else:
+ # cur_decay_value = (1 + step) / (10 + step)
+ cur_decay_value = self.decay
+
+ cur_decay_value = min(cur_decay_value, self.decay)
+ # make sure decay is not smaller than min_decay
+ cur_decay_value = max(cur_decay_value, self.min_decay)
+ return cur_decay_value
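+
+    # Quick numerical check of the warmup schedule described in the class
+    # docstring (inv_gamma=1, power=2/3):
+    #   step = 31_600    -> 1 - (1 + 31_600) ** (-2 / 3)    ~= 0.999
+    #   step = 1_000_000 -> 1 - (1 + 1_000_000) ** (-2 / 3) ~= 0.9999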
+
+def get_configs(yaml_path: str, cli_configs: Optional[List[str]] = None, **kwargs) -> DictConfig:
+    yaml_configs = OmegaConf.load(yaml_path)
+    cli_configs = OmegaConf.from_cli(cli_configs or [])
+
+ configs = OmegaConf.merge(yaml_configs, cli_configs, kwargs)
+ OmegaConf.resolve(configs) # resolve ${...} placeholders
+ return configs
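+
+# Illustrative only; the YAML path and dotted keys below are hypothetical:
+#
+#     configs = get_configs("configs/train.yaml", ["optimizer.lr=1e-4"], seed=42)
+#     print(configs.optimizer.lr)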
+
+def get_optimizer(name: str, params: Iterable[Parameter], **kwargs) -> Optimizer:
+ if name == "adamw":
+ return optim.AdamW(params=params, **kwargs)
+ else:
+ raise NotImplementedError(f"Not implemented optimizer: {name}")
+
+def get_lr_scheduler(name: str, optimizer: Optimizer, **kwargs) -> LRScheduler:
+ if name == "one_cycle":
+ return lr_scheduler.OneCycleLR(
+ optimizer,
+ max_lr=kwargs["max_lr"],
+ total_steps=kwargs["total_steps"],
+ pct_start=kwargs["pct_start"],
+ )
+ elif name == "cosine_warmup":
+ return get_scheduler(
+ "cosine", optimizer,
+ num_warmup_steps=kwargs["num_warmup_steps"],
+ num_training_steps=kwargs["total_steps"],
+ )
+ elif name == "constant_warmup":
+ return get_scheduler(
+ "constant_with_warmup", optimizer,
+ num_warmup_steps=kwargs["num_warmup_steps"],
+ num_training_steps=kwargs["total_steps"],
+ )
+ elif name == "constant":
+ return lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda _: 1)
+ elif name == "linear_decay":
+ return lr_scheduler.LambdaLR(
+ optimizer=optimizer,
+ lr_lambda=lambda epoch: max(0., 1. - epoch / kwargs["total_epochs"]),
+ )
+ else:
+ raise NotImplementedError(f"Not implemented lr scheduler: {name}")
+
+def save_experiment_params(
+ args: Namespace,
+ configs: DictConfig,
+ save_dir: str
+) -> Dict[str, Any]:
+ params = OmegaConf.merge(configs, {"args": {k: str(v) for k, v in vars(args).items()}})
+ OmegaConf.save(params, os.path.join(save_dir, "params.yaml"))
+ return dict(params)
+
+
+def save_model_architecture(model: Module, save_dir: str) -> None:
+ num_buffers = sum(b.numel() for b in model.buffers())
+ num_params = sum(p.numel() for p in model.parameters())
+ num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ message = f"Number of buffers: {num_buffers}\n" +\
+ f"Number of trainable / all parameters: {num_trainable_params} / {num_params}\n\n" +\
+ f"Model architecture:\n{model}"
+
+ with open(os.path.join(save_dir, "model.txt"), "w") as f:
+ f.write(message)
\ No newline at end of file
diff --git a/src/utils/typing_utils.py b/src/utils/typing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..862515ea55fd1c01d03c824547487c2b89f96eff
--- /dev/null
+++ b/src/utils/typing_utils.py
@@ -0,0 +1,20 @@
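+"""Shared typing re-exports: standard `typing` names plus commonly used
+argparse/OmegaConf, torch, and accelerate types, intended to be star-imported
+(e.g. `from src.utils.typing_utils import *` in `train_utils.py`)."""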
+from typing import *
+
+from argparse import Namespace
+from collections import defaultdict
+from omegaconf import DictConfig, ListConfig
+from omegaconf.base import ContainerMetadata, Metadata
+from omegaconf.nodes import AnyNode
+
+from torch import Tensor
+from torch.nn import Parameter, Module
+from torch.nn.parallel import DistributedDataParallel
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from torch.utils.data import DataLoader
+
+from accelerate.optimizer import AcceleratedOptimizer
+from accelerate.scheduler import AcceleratedScheduler
+from accelerate.data_loader import DataLoaderShard
+
+