Upload 65 files

This view is limited to 50 files because it contains too many changes.
- .gitattributes +3 -0
- Dockerfile +88 -0
- LICENSE.md +159 -0
- NeuralJacobianFields/MeshProcessor.py +473 -0
- NeuralJacobianFields/PoissonSystem.py +796 -0
- NeuralJacobianFields/SourceMesh.py +152 -0
- app.py +230 -0
- asset_visualization/armor.gif +3 -0
- example_config.yml +56 -0
- get_embeddings.py +58 -0
- loop.py +436 -0
- main.py +93 -0
- meshes/UV.png +3 -0
- meshes/dress_shortsleeve.obj +0 -0
- meshes/longsleeve.mtl +13 -0
- meshes/longsleeve.obj +0 -0
- meshes/poncho.obj +0 -0
- meshes/tanktop.obj +0 -0
- meshes/tshirt.obj +0 -0
- meshes_target/jacket_sdf_new.obj +3 -0
- nvdiffmodeling/LICENSE.txt +97 -0
- nvdiffmodeling/src/material.py +149 -0
- nvdiffmodeling/src/mesh.py +510 -0
- nvdiffmodeling/src/obj.py +215 -0
- nvdiffmodeling/src/regularizer.py +197 -0
- nvdiffmodeling/src/render.py +223 -0
- nvdiffmodeling/src/renderutils/__init__.py +10 -0
- nvdiffmodeling/src/renderutils/bsdf.py +126 -0
- nvdiffmodeling/src/renderutils/c_src/bsdf.cu +551 -0
- nvdiffmodeling/src/renderutils/c_src/bsdf.h +70 -0
- nvdiffmodeling/src/renderutils/c_src/common.cpp +71 -0
- nvdiffmodeling/src/renderutils/c_src/common.h +38 -0
- nvdiffmodeling/src/renderutils/c_src/loss.cu +207 -0
- nvdiffmodeling/src/renderutils/c_src/loss.h +35 -0
- nvdiffmodeling/src/renderutils/c_src/mesh.cu +90 -0
- nvdiffmodeling/src/renderutils/c_src/mesh.h +20 -0
- nvdiffmodeling/src/renderutils/c_src/normal.cu +179 -0
- nvdiffmodeling/src/renderutils/c_src/normal.h +24 -0
- nvdiffmodeling/src/renderutils/c_src/tensor.h +86 -0
- nvdiffmodeling/src/renderutils/c_src/torch_bindings.cpp +793 -0
- nvdiffmodeling/src/renderutils/c_src/vec3f.h +106 -0
- nvdiffmodeling/src/renderutils/c_src/vec4f.h +22 -0
- nvdiffmodeling/src/renderutils/loss.py +40 -0
- nvdiffmodeling/src/renderutils/ops.py +425 -0
- nvdiffmodeling/src/renderutils/tests/test_bsdf.py +266 -0
- nvdiffmodeling/src/renderutils/tests/test_loss.py +61 -0
- nvdiffmodeling/src/renderutils/tests/test_mesh.py +90 -0
- nvdiffmodeling/src/renderutils/tests/test_perf.py +58 -0
- nvdiffmodeling/src/texture.py +151 -0
- nvdiffmodeling/src/util.py +354 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+asset_visualization/armor.gif filter=lfs diff=lfs merge=lfs -text
+meshes_target/jacket_sdf_new.obj filter=lfs diff=lfs merge=lfs -text
+meshes/UV.png filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,88 @@
# Use CUDA base image
FROM nvidia/cuda:11.8-devel-ubuntu20.04

# Set environment variables
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}

# Install system dependencies
RUN apt-get update && apt-get install -y \
    python3.8 \
    python3.8-dev \
    python3-pip \
    git \
    wget \
    curl \
    build-essential \
    cmake \
    ninja-build \
    libgl1-mesa-glx \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

# Create symbolic link for python
RUN ln -s /usr/bin/python3.8 /usr/bin/python

# Upgrade pip
RUN python -m pip install --upgrade pip

# Set working directory
WORKDIR /app

# Copy requirements first for better caching
COPY requirements.txt .

# Install PyTorch with CUDA support
RUN pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118

# Install basic dependencies
RUN pip install -r requirements.txt

# Install CLIP
RUN pip install git+https://github.com/openai/CLIP.git

# Create packages directory
RUN mkdir -p packages

# Install nvdiffrast
WORKDIR /app/packages
RUN git clone https://github.com/NVlabs/nvdiffrast.git && \
    cd nvdiffrast && \
    pip install .

# Install PyTorch3D
WORKDIR /app/packages
RUN git clone https://github.com/facebookresearch/pytorch3d.git && \
    cd pytorch3d && \
    pip install .

# Install Fashion-CLIP
WORKDIR /app/packages
RUN git clone https://github.com/patrickjohncyh/fashion-clip.git && \
    cd fashion-clip && \
    pip install appdirs boto3 annoy validators transformers datasets

# Return to app directory
WORKDIR /app

# Copy application files
COPY . .

# Create necessary directories
RUN mkdir -p outputs meshes meshes_target

# Make setup script executable
RUN chmod +x setup_spaces.py

# Expose port
EXPOSE 7860

# Set the default command
CMD ["python", "app.py"]
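A quick sanity check for the image this Dockerfile produces (illustrative only, not part of the upload); it assumes nothing beyond the torch==2.2.2+cu118 install performed above and that the container is started with GPU access:

# Run inside the built container, e.g. `docker run --gpus all <image> python -c "..."`.
# Confirms the CUDA-enabled PyTorch wheel installed above is importable and sees a GPU.
import torch

print(torch.__version__)          # the cu118 wheel above should report 2.2.2+cu118
print(torch.cuda.is_available())  # True only when a GPU is exposed to the container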
LICENSE.md
ADDED
@@ -0,0 +1,159 @@
# Attribution-NonCommercial 4.0 International

> *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.*
>
> ### Using Creative Commons Public Licenses
>
> Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
>
> * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
>
> * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).

## Creative Commons Attribution-NonCommercial 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

### Section 1 – Definitions.

a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.

b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.

c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.

d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.

e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.

f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.

g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.

h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.

i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.

j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.

k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.

l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.

### Section 2 – Scope.

a. ___License grant.___

1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:

A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and

B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.

2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.

3. __Term.__ The term of this Public License is specified in Section 6(a).

4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.

5. __Downstream recipients.__

A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.

B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.

6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).

b. ___Other rights.___

1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.

2. Patent and trademark rights are not licensed under this Public License.

3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.

### Section 3 – License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the following conditions.

a. ___Attribution.___

1. If You Share the Licensed Material (including in modified form), You must:

A. retain the following if it is supplied by the Licensor with the Licensed Material:

i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);

ii. a copyright notice;

iii. a notice that refers to this Public License;

iv. a notice that refers to the disclaimer of warranties;

v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;

B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and

C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.

2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.

3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.

4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.

### Section 4 – Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:

a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;

b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and

c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.

### Section 5 – Disclaimer of Warranties and Limitation of Liability.

a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__

b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__

c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.

### Section 6 – Term and Termination.

a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.

b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:

1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or

2. upon express reinstatement by the Licensor.

For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.

c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.

d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.

### Section 7 – Other Terms and Conditions.

a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.

b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.

### Section 8 – Interpretation.

a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.

b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.

c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.

d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.

> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
>
> Creative Commons may be contacted at creativecommons.org
NeuralJacobianFields/MeshProcessor.py
ADDED
@@ -0,0 +1,473 @@
from multiprocessing import process
import warnings
warnings.filterwarnings("ignore")


from scipy.sparse import load_npz, save_npz
from .PoissonSystem import poisson_system_matrices_from_mesh, PoissonSystemMatrices, SparseMat
import os
import trimesh
from easydict import EasyDict
import numpy
import numpy as np
import scipy
import scipy.sparse
import igl
from scipy.sparse import save_npz
from time import time
import torch

NUM_SAMPLES = 1024
WKS_DIM = 100

class MeshProcessor:
    '''
    Extracts all preprocessing-related data (sample points for pointnet; wave-kernel-signature, etc.)
    '''
    def __init__(self, vertices, faces, ttype, source_dir=None, from_file=False,
                 cpuonly=False, load_wks_samples=False, load_wks_centroids=False,
                 compute_splu=True, load_splu=False):
        '''
        :param vertices:
        :param faces:
        :param ttype: the torch data type to use (float, half, double)
        :param source_dir: the directory to load the preprocessed data from; if given, will try to load the data before computing, if not given, always compute
        '''

        self.ttype = ttype
        self.num_samples = NUM_SAMPLES
        self.vertices = vertices.squeeze()
        self.faces = faces.squeeze()
        self.normals = igl.per_vertex_normals(self.vertices, self.faces)
        # self.__use_wks = use_wks
        self.samples = EasyDict()
        self.samples.xyz = None
        self.samples.normals = None
        self.samples.wks = None
        self.centroids = EasyDict()
        self.centroids.points_and_normals = None
        self.centroids.wks = None
        self.diff_ops = EasyDict()
        self.diff_ops.splu = EasyDict()
        self.diff_ops.splu.L = None
        self.diff_ops.splu.U = None
        self.diff_ops.splu.perm_c = None
        self.diff_ops.splu.perm_r = None
        self.diff_ops.frames = None
        self.diff_ops.rhs = None
        self.diff_ops.grad = None
        self.diff_ops.poisson_sys_mat = None
        self.faces_wks = None
        self.vert_wks = None
        self.diff_ops.poisson = None
        self.source_dir = source_dir
        self.from_file = from_file
        self.cpuonly = cpuonly
        self.load_wks_samples = load_wks_samples
        self.load_wks_centroids = load_wks_centroids
        self.compute_splu = compute_splu
        self.load_splu = load_splu

    @staticmethod
    def meshprocessor_from_directory(source_dir, ttype, cpuonly=False, load_wks_samples=False, load_wks_centroids=False):
        try:
            vertices = np.load(os.path.join(source_dir, "vertices.npy"))
            faces = np.load(os.path.join(source_dir, "faces.npy"))
        except:
            print(os.path.join(source_dir, "vertices.npy"))
            import traceback
            traceback.print_exc()
        return MeshProcessor(vertices, faces, ttype, source_dir, cpuonly=cpuonly, load_wks_samples=load_wks_samples, load_wks_centroids=load_wks_centroids, compute_splu=False)

    @staticmethod
    def meshprocessor_from_file(fname, ttype, cpuonly=False, load_wks_samples=False, load_wks_centroids=False):
        if fname[-4:] == '.obj':
            V, _, _, F, _, _ = igl.read_obj(fname)
        elif fname[-4:] == '.off':
            V, F, _ = igl.read_off(fname)
        elif fname[-4:] == '.ply':
            V, F = igl.read_triangle_mesh(fname)
        return MeshProcessor(V, F, ttype, os.path.dirname(fname), True, cpuonly=cpuonly, load_wks_samples=load_wks_samples, load_wks_centroids=load_wks_centroids, compute_splu=False)

    @staticmethod
    def meshprocessor_from_array(vertices, faces, source_dir, ttype, cpuonly=False, load_wks_samples=False, load_wks_centroids=False):
        return MeshProcessor(vertices, faces, ttype, source_dir, cpuonly=cpuonly, load_wks_samples=load_wks_samples, load_wks_centroids=load_wks_centroids, compute_splu=False)



    def get_vertices(self):
        return self.vertices

    def get_faces(self):
        return self.faces

    def load_centroids(self):
        self.centroids.points_and_normals = np.load(os.path.join(self.source_dir, "centroids_and_normals.npy"))
        if self.load_wks_centroids:
            self.centroids.wks = np.load(os.path.join(self.source_dir, "centroids_wks.npy"))

    def get_samples(self):
        if self.samples.xyz is None:
            if True:
                try:
                    self.load_samples()
                except Exception as e:
                    self.compute_samples()
                    self.save_samples()
        return self.samples

    def load_samples(self):
        if self.samples.xyz is None:
            self.samples.xyz = np.load(os.path.join(self.source_dir, 'samples.npy'))
        if self.samples.normals is None:
            self.samples.normals = np.load(os.path.join(self.source_dir, 'samples_normals.npy'))
        if self.load_wks_samples:
            if self.samples.wks is None:
                self.samples.wks = np.load(os.path.join(self.source_dir, 'samples_wks.npy'))
            if self.centroids.wks is None:
                self.centroids.wks = np.load(os.path.join(self.source_dir, 'centroid_wks.npy'))

    def save_samples(self):
        os.makedirs(self.source_dir, exist_ok=True)
        np.save(os.path.join(self.source_dir, 'samples.npy'), self.samples.xyz)
        np.save(os.path.join(self.source_dir, 'samples_normals.npy'), self.samples.normals)
        if self.load_wks_samples:
            np.save(os.path.join(self.source_dir, 'samples_wks.npy'), self.samples.wks)
            np.save(os.path.join(self.source_dir, 'centroid_wks.npy'), self.centroids.wks)

    def compute_samples(self):
        sstime = time()
        if self.load_wks_centroids or self.load_wks_centroids:
            self.computeWKS()
        # print(f"WKS {time() - sstime}")
        pt_samples, normals_samples, wks_samples, bary = self.sample_points(self.num_samples)
        self.samples.xyz = pt_samples
        self.samples.normals = normals_samples
        self.samples.wks = wks_samples
        self.centroids.wks = self.faces_wks

    def get_centroids(self):
        if self.centroids.points_and_normals is None:
            if True:
                try:
                    self.load_centroids()
                except Exception as e:
                    self.compute_centroids()
                    # self.save_centroids() # centroid WKS and samples WKS are intertwined right now and you cannot really use one without the other. So this is redundant with function save_samples
        return self.centroids

    def compute_centroids(self):
        m = trimesh.Trimesh(vertices=self.vertices, faces=self.faces, process=False)
        self.centroids.points_and_normals = np.hstack((np.mean(m.triangles, axis=1), m.face_normals))
        self.get_samples()  # this is to compute WKS for centroids

    def get_differential_operators(self):
        if self.diff_ops.grad is None:
            if True:
                try:
                    self.load_differential_operators()
                except Exception as e:
                    warnings.warn(f'while loading data, got file not exists exception: {e} ')
                    self.compute_differential_operators()
                    self.save_differential_operators()
        if self.load_splu:
            self.get_poisson_system()
        return self.diff_ops


    def load_poisson_system(self):
        try:
            self.diff_ops.splu.L = load_npz(os.path.join(self.source_dir, 'lap_L.npz'))
            self.diff_ops.splu.U = load_npz(os.path.join(self.source_dir, 'lap_U.npz'))
            self.diff_ops.splu.perm_c = np.load(os.path.join(self.source_dir, 'lap_perm_c.npy'))
            self.diff_ops.splu.perm_r = np.load(os.path.join(self.source_dir, 'lap_perm_r.npy'))
        except:
            print(f"FAILED load poisson on: {os.path.join(self.source_dir)}")
            raise Exception("FAILED load poisson on: {os.path.join(self.source_dir)}")


    def load_differential_operators(self):
        self.diff_ops.rhs = SparseMat.from_coo(load_npz(os.path.join(self.source_dir, 'new_rhs.npz')), ttype=torch.float64)
        self.diff_ops.grad = SparseMat.from_coo(load_npz(os.path.join(self.source_dir, 'new_grad.npz')), ttype=torch.float64)
        self.diff_ops.frames = np.load(os.path.join(self.source_dir, 'w.npy'))
        self.diff_ops.laplacian = SparseMat.from_coo(load_npz(os.path.join(self.source_dir, 'laplacian.npz')), ttype=torch.float64)

    def save_differential_operators(self):
        save_npz(os.path.join(self.source_dir, 'new_rhs.npz'), self.diff_ops.rhs.to_coo())
        save_npz(os.path.join(self.source_dir, 'new_grad.npz'), self.diff_ops.grad.to_coo())
        np.save(os.path.join(self.source_dir, 'w.npy'), self.diff_ops.frames)
        save_npz(os.path.join(self.source_dir, 'laplacian.npz'), self.diff_ops.laplacian.to_coo())

    def compute_differential_operators(self):
        '''
        process the given mesh
        '''
        poisson_sys_mat = poisson_system_matrices_from_mesh(V=self.vertices, F=self.faces, cpuonly=self.cpuonly)
        self.diff_ops.grad = poisson_sys_mat.igl_grad
        self.diff_ops.rhs = poisson_sys_mat.rhs
        self.diff_ops.laplacian = poisson_sys_mat.lap
        self.diff_ops.frames = poisson_sys_mat.w
        self.diff_ops.poisson_sys_mat = poisson_sys_mat


    def compute_poisson(self):
        poissonsolver = poissonbuilder.compute_poisson_solver_from_laplacian(compute_splu=self.compute_splu)
        # new_grad = poissonbuilder.get_new_grad() # This is now done in poisson_system_matrices_from_mesh
        if self.compute_splu:
            self.diff_ops.splu.L, self.diff_ops.splu.U, self.diff_ops.splu.perm_c, self.diff_ops.splu.perm_r = poissonbuilder.compute_splu()
        self.diff_ops.frames = poissonbuilder.w


    def prepare_differential_operators_for_use(self, ttype):
        diff_ops = self.get_differential_operators()  # call 1
        ## WARNING : we commented these two lines because they seemed redundant.
        if self.diff_ops.poisson_sys_mat is None:  # not created if the diff ops were loaded from disk
            diff_ops.poisson_sys_mat = PoissonSystemMatrices(self.vertices, self.faces, diff_ops.grad, diff_ops.rhs, diff_ops.frames, ttype, lap=diff_ops.laplacian, cpuonly=self.cpuonly)

        self.diff_ops.poisson_solver = diff_ops.poisson_sys_mat.create_poisson_solver()  # call 2
        self.diff_ops.MyCuSPLU_solver = diff_ops.poisson_sys_mat.create_poisson_solver()  # create_poisson_solver_from_splu_old

    def get_writeable(self):
        '''
        get dictionaries to write numpy and npz
        :return: two args, np, npz, each dicts with field_name --> data to save
        '''
        out_np = {}
        out_npz = {}
        out_np['vertices'] = self.vertices
        out_np['faces'] = self.faces
        if self.samples is not None:
            out_np["samples"] = self.samples.xyz
            out_np["samples_normals"] = self.samples.normals
            out_np["samples_wks"] = self.samples.wks
        if self.centroids is not None:
            out_np["centroids_wks"] = self.centroids.wks
            out_np["centroids_and_normals"] = self.centroids.points_and_normals
        if self.diff_ops is not None:
            out_np['lap_perm_c'] = self.diff_ops.splu.perm_c
            out_np['lap_perm_r'] = self.diff_ops.splu.perm_r
            out_np['w'] = self.diff_ops.frames
            out_npz['new_grad'] = self.diff_ops.grad.to_coo()
            out_npz['new_rhs'] = self.diff_ops.rhs
            out_npz['lap_L'] = self.diff_ops.splu.L
            out_npz['lap_U'] = self.diff_ops.splu.U
            out_npz['lap'] = self.diff_ops.poisson.lap
        return {key: value for key, value in out_np.items() if value is not None}, {key: value for key, value in out_npz.items() if value is not None}

    def get_data(self, key, file_type='npy'):
        if key == 'samples':
            return self.get_samples().xyz
        elif key == "samples_normals":
            return self.get_samples().normals
        elif key == "samples_wks":
            return self.get_samples().wks
        elif key == 'vertices':
            return self.vertices
        elif key == 'faces':
            return self.faces
        if file_type == 'npy':
            return np.load(os.path.join(self.source_dir, f'{key}.npy'))
        elif file_type == 'npz':
            return load_npz(os.path.join(self.source_dir, f'{key}.npz'))
        else:
            raise RuntimeError("wrong file type")

    def computeWKS(self):
        if self.faces_wks is None or self.vert_wks is None:
            st = time()
            w = WaveKernelSignature(self.vertices, self.faces, top_k_eig=50)
            w.compute()
            print(f"Ellapsed {time() - st}")
            wk = w.wks
            faces_wks = np.zeros((self.faces.shape[0], wk.shape[1]))
            for i in range(3):
                faces_wks += wk[self.faces[:, i], :]
            faces_wks /= 3
            self.faces_wks = faces_wks
            self.vert_wks = wk
            assert (self.faces_wks.shape[0] == self.faces.shape[0])
            assert (self.vert_wks.shape[0] == self.vertices.shape[0])


    def sample_points(self, n):
        bary, found_faces = igl.random_points_on_mesh(n, self.vertices, self.faces)
        vert_ind = self.faces[found_faces]
        point_samples = self.vertices[vert_ind[:,0]] * bary[:,0:1] + self.vertices[vert_ind[:,1]] * bary[:,1:2] + self.vertices[vert_ind[:,2]] * bary[:,2:3]
        normal_samples = self.normals[vert_ind[:,0]] * bary[:,0:1] + self.normals[vert_ind[:,1]] * bary[:,1:2] + self.normals[vert_ind[:,2]] * bary[:,2:3]
        wks_samples = None
        if self.load_wks_centroids or self.load_wks_samples:
            wks_samples = self.vert_wks[vert_ind[:,0]] * bary[:,0:1] + self.vert_wks[vert_ind[:,1]] * bary[:,1:2] + self.vert_wks[vert_ind[:,2]] * bary[:,2:3]
        return point_samples, normal_samples, wks_samples, bary

# This is insane to me
def sample_points(V, F, n):
    '''
    samples n points on the given mesh, along with normals and wks. Also return WKS of original faces (by averaging wks of 3 vertices of each face)
    :return:
    '''
    newF = F
    newV = V
    for iter in range(n):
        newV, newF = _sample_point(newV, newF)

    w = WaveKernelSignature(newV, newF, top_k_eig=100)
    w.compute()
    wk = w.wks
    sample_ks = wk[len(V):, :]
    org_ks = wk[:len(V), :]
    normals = igl.per_vertex_normals(newV, newF)
    normals = normals[len(V):, :]

    # get per-face wks by averaging its vertices
    faces_wks = np.zeros((F.shape[0], org_ks.shape[1]))
    for i in range(3):
        faces_wks += org_ks[F[:, i], :]
    faces_wks /= 3
    return newV[len(V):, :], normals, sample_ks, faces_wks


def _sample_point(VV, FF):
    while (True):
        bary, found_faces = igl.random_points_on_mesh(1, VV, FF)
        if (found_faces >= FF.shape[0]):
            continue
        # use to be 0.01
        if not numpy.any(bary < 0.05):
            break
    ret = numpy.zeros((1, VV.shape[1]))
    for i in range(VV.shape[1]):
        res = np.multiply(VV[FF[found_faces, :], i], bary)
        ret[:, i] = np.sum(res)
    newF = FF
    new_index = len(VV)
    new_tris = _insert_triangle(FF[found_faces, :], new_index)
    newF = numpy.concatenate((newF, new_tris), axis=0)
    newF = numpy.delete(newF, found_faces, axis=0)
    newV = numpy.concatenate((VV, ret), 0)
    return newV, newF


def _insert_triangle(old_tri, new_index):
    d = new_index
    a, b, c = (old_tri[0], old_tri[1], old_tri[2])
    new_tris = numpy.array([[a, b, d], [b, c, d], [c, a, d]])
    return new_tris




class WaveKernelSignatureError(Exception):
    pass

class WaveKernelSignature:
    '''
    Computes wave kernel signature for a given mesh
    '''
    def __init__(self,
                 vertices,
                 faces,
                 top_k_eig=200,
                 timestamps=WKS_DIM):
        # vertices, faces are both numpy arrays.
        self.vertices = vertices
        self.faces = faces

        # self.vertices_gpu = torch.from_numpy(vertices).cuda()
        # self.faces_gpu = torch.from_numpy(faces).cuda()

        self.top_k_eig = top_k_eig
        self.timestamps = timestamps
        self.max_iter = 10000

    def compute(self):
        '''
        compute the wks. Afterwards WKS is stored in self.wks
        '''
        cp = igl.connected_components(igl.adjacency_matrix(self.faces))
        assert(cp[0] == 1), f"{cp}"
        L = -igl.cotmatrix(self.vertices, self.faces)  # this is fast, 0.04 seconds
        M = igl.massmatrix(self.vertices, self.faces, igl.MASSMATRIX_TYPE_VORONOI)
        # assert(not numpy.any(numpy.isinf(L)))
        try:
            try:
                self.eig_vals, self.eig_vecs = scipy.sparse.linalg.eigsh(
                    L, self.top_k_eig, M, sigma=0, which='LM', maxiter=self.max_iter)
            except:
                self.eig_vals, self.eig_vecs = scipy.sparse.linalg.eigsh(
                    L, self.top_k_eig, M, sigma=1e-4, which='LM', maxiter=self.max_iter)
        except:
            raise WaveKernelSignatureError("Error in computing WKS")

        # print(np.linalg.norm(self.eig_vecs, axis=0, keepdims=True))
        # print(np.max(self.eig_vecs))
        self.eig_vecs /= 200  # np.linalg.norm(self.eig_vecs, axis=0, keepdims=True)
        # np.save("norm_v2.npy", np.max(np.abs(self.eig_vecs), axis=0, keepdims=True))
        # np.save("norm_v2.npy", np.max(np.abs(self.eig_vecs), axis=0, keepdims=True))
        # print(np.linalg.norm(self.eig_vecs, axis=0))
        # print(np.max(self.eig_vecs, axis=0))
        # print(np.min(self.eig_vecs, axis=0))
        # print(self.eig_vals)

        # self.eig_vecs /= np.load('norm_v1.npy')

        # self.eig_vecs = self.eig_vecs / np.max(np.abs(self.eig_vecs), axis=0, keepdims=True)
        # self.eig_vecs = self.eig_vecs * np.load('norm_v2.npy')


        # nn = np.load('norm2.npy')
        # self.eig_vecs /= nn[:,:50]


        # ==== VISUALIZATION CODE ==========
        if False:
            num_mesh_to_viz = 6
            meshes = []
            for i in range(num_mesh_to_viz):
                meshes.append(trimesh.Trimesh(self.vertices + np.array([i*1, 0, 0]), self.faces, process=False))

            # mesh = meshes[0].union( meshes[1])
            # mesh = mesh.union( meshes[2])
            # mesh = mesh.union( meshes[3])
            meshes = [trimesh.util.concatenate(meshes)]

            from vedo import trimesh2vedo, show, screenshot, Plotter

            vp = Plotter(axes=0, offscreen=True)

            vmeshes = trimesh2vedo(meshes)
            cmaps = ('jet', 'PuOr', 'viridis')
            scals = self.eig_vecs[:, :num_mesh_to_viz].transpose((1, 0)).reshape(-1)
            vmeshes[0].cmap(cmaps[0], scals).lighting('plastic')

            # add a 2D scalar bar to a mesh
            vmeshes[0].addScalarBar(title=f"scalarbar #{0}", c='k')

            vp.show(vmeshes, axes=1)
            screenshot(f"test_{time()}.png")
            import sys
            sys.exit(0)
        # ================




        # range between biggest and smallest eigenvalue :
        # 6 0.09622419080119388
        # 6_bis 0.09651935545457718
        delta = (np.log(self.eig_vals[-1]) - np.log(self.eig_vals[1])) / self.timestamps
        sigma = 7 * delta
        e_min = np.log(self.eig_vals[1]) + 2 * delta
        e_max = np.log(self.eig_vals[-1]) - 2 * delta
        es = np.linspace(e_min, e_max, self.timestamps)  # T
        self.delta = delta


        coef = np.expand_dims(es, 0) - np.expand_dims(np.log(self.eig_vals[1:]), 1)  # (K-1)xT
        coef = np.exp(-np.square(coef) / (2 * sigma * sigma))  # (K-1)xT, element-wise square
        sum_coef = coef.sum(0)  # T
        K = np.matmul(np.square(self.eig_vecs[:, 1:]), coef)  # VxT. Scaling of the eigenvectors by coef. Coef depends only on the eigenvalues. Triangulation agnostic.
        self.wks = K / np.expand_dims(sum_coef, 0)  # VxT. Scaling of the eigenvectors by sum_coef. Coef depends only on the eigenvalues. Triangulation agnostic.
        # print(np.linalg.norm(self.wks, axis=0))
        # print(np.linalg.norm(self.wks, axis=1))
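A minimal usage sketch for the MeshProcessor module above (illustrative only, not part of the upload); it assumes igl, trimesh, and easydict are installed and uses meshes/tshirt.obj from the file list as input. On first use, get_samples() caches samples.npy and samples_normals.npy next to the mesh:

import torch
from NeuralJacobianFields.MeshProcessor import MeshProcessor

# Build a processor from one of the uploaded garment meshes and query its
# cached/computed surface samples and per-face centroids.
mp = MeshProcessor.meshprocessor_from_file('meshes/tshirt.obj', torch.float)
samples = mp.get_samples()       # EasyDict with .xyz, .normals (and .wks when WKS loading is enabled)
centroids = mp.get_centroids()   # per-face centroid positions concatenated with face normals
print(samples.xyz.shape, centroids.points_and_normals.shape)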
NeuralJacobianFields/PoissonSystem.py
ADDED
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
import igl
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import time
|
6 |
+
|
7 |
+
from scipy.sparse import diags,coo_matrix
|
8 |
+
from scipy.sparse import csc_matrix as sp_csc
|
9 |
+
|
10 |
+
|
11 |
+
USE_TORCH_SPARSE = True ## This uses TORCH_SPARSE instead of TORCH.SPARSE
|
12 |
+
|
13 |
+
# This four are mutually exclusive
|
14 |
+
USE_CUPY = False ## This uses CUPY LU decomposition on GPU
|
15 |
+
USE_CHOLESPY_GPU = True ## This uses cholesky decomposition on GPU
|
16 |
+
USE_CHOLESPY_CPU = False ## This uses cholesky decomposition on CPU
|
17 |
+
USE_SCIPY = False ## This uses CUPY LU decomposition on CPU
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
# If USE_SCIPY = True, wether or not to use enhanced backend
|
22 |
+
USE_SCIKITS_UMFPACK = False ## This uses UMFPACK backend for scipy instead of naive scipy.
|
23 |
+
|
24 |
+
if USE_CHOLESPY_GPU or USE_CHOLESPY_CPU:
|
25 |
+
from cholespy import CholeskySolverD, MatrixType
|
26 |
+
|
27 |
+
if USE_CUPY and torch.cuda.is_available():
|
28 |
+
from cupyx.scipy.sparse.linalg import spsolve_triangular
|
29 |
+
from cupyx.scipy.sparse import csr_matrix
|
30 |
+
import cupy
|
31 |
+
from torch.utils.dlpack import to_dlpack, from_dlpack
|
32 |
+
|
33 |
+
from scipy.sparse.linalg import splu as scipy_splu
|
34 |
+
from scipy.sparse.linalg import spsolve_triangular, spsolve
|
35 |
+
if USE_SCIPY:
|
36 |
+
if USE_SCIKITS_UMFPACK:
|
37 |
+
# This is a bit slower in practice
|
38 |
+
# https://stackoverflow.com/questions/64401503/is-there-a-way-to-further-improve-sparse-solution-times-using-python
|
39 |
+
from scikits.umfpack import splu as scipy_splu
|
40 |
+
else:
|
41 |
+
import scipy.sparse.linalg as lg
|
42 |
+
lg.use_solver(useUmfpack=False)
|
43 |
+
# Slight performance gain with True
|
44 |
+
# conda install -c conda-forge scikit-umfpack
|
45 |
+
# forward pass goes from 0.038 to 0.036
|
46 |
+
# assumeSortedIndices=True Does not bring any boost
|
47 |
+
from scipy.sparse.linalg import splu as scipy_splu
|
48 |
+
from scipy.sparse.linalg import spsolve_triangular, spsolve
|
49 |
+
|
50 |
+
|
51 |
+
if USE_TORCH_SPARSE:
|
52 |
+
import torch_sparse
|
53 |
+
|
54 |
+
|
55 |
+
USE_UGLY_PATCH_FOR_CUPY_ERROR = False
|
56 |
+
|
57 |
+
|
58 |
+
class SparseMat:
|
59 |
+
'''
|
60 |
+
Sparse matrix object represented in the COO format
|
61 |
+
Refacto : consider killing this object, byproduct of torch_sparse instead of torch.sparse (new feature)
|
62 |
+
'''
|
63 |
+
|
64 |
+
@staticmethod
|
65 |
+
def from_M(M,ttype):
|
66 |
+
return SparseMat(M[0],M[1],M[2],M[3],ttype)
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def from_coo(coo,ttype):
|
70 |
+
inds = numpy.vstack((coo.row,coo.col))
|
71 |
+
return SparseMat(inds,coo.data,coo.shape[0],coo.shape[1],ttype)
|
72 |
+
|
73 |
+
def __init__(self,inds,vals,n,m,ttype):
|
74 |
+
self.n = n
|
75 |
+
self.m = m
|
76 |
+
self.vals = vals
|
77 |
+
self.inds = inds
|
78 |
+
assert(inds.shape[0] == 2)
|
79 |
+
assert(inds.shape[1] == vals.shape[0])
|
80 |
+
assert(np.max(inds[0,:]) <= n)
|
81 |
+
assert(np.max(inds[1,:] <= m))
|
82 |
+
#TODO figure out how to extract the I,J,V,m,n from this, then load a COO mat directly from npz
|
83 |
+
#self.coo_mat = coo_matrix((cupy.array(self.vals), (cupy.array(self.inds[0,:]), cupy.array(self.inds[1,:]))))
|
84 |
+
self.vals = torch.from_numpy(self.vals).type(ttype).contiguous()
|
85 |
+
self.inds = torch.from_numpy(self.inds).type(torch.int64).contiguous()
|
86 |
+
|
87 |
+
def to_coo(self):
|
88 |
+
return coo_matrix((self.vals, (self.inds[0,:], self.inds[1,:])), shape = (self.n, self.m))
|
89 |
+
|
90 |
+
def to_csc(self):
|
91 |
+
return sp_csc((self.vals, (self.inds[0,:], self.inds[1,:])), shape = (self.n, self.m))
|
92 |
+
|
93 |
+
def to_cholesky(self):
|
94 |
+
return CholeskySolverD(self.n, self.inds[0,:], self.inds[1,:], self.vals, MatrixType.COO)
|
95 |
+
|
96 |
+
def to(self,device):
|
97 |
+
self.vals = self.vals.to(device)
|
98 |
+
self.inds = self.inds.to(device)
|
99 |
+
return self
|
100 |
+
|
101 |
+
def pin_memory(self):
|
102 |
+
return
|
103 |
+
# self.vals.pin_memory()
|
104 |
+
# self.inds.pin_memory()
|
105 |
+
|
106 |
+
def multiply_with_dense(self,dense):
|
107 |
+
if USE_TORCH_SPARSE:
|
108 |
+
res = torch_sparse.spmm(self.inds,self.vals, self.n, self.m, dense)
|
109 |
+
# 1000 for loop on the above line takes 0.13 sec. Fast but annoying to have this dependency
|
110 |
+
else:
|
111 |
+
# Somehow this is not implemented for now?
|
112 |
+
# res = torch.smm(torch.sparse_coo_tensor(self.inds,self.vals) , (dense.float())).to_dense().to(dense.device)
|
113 |
+
# 1000 for loop on the above line takes 10 sec on the CPU. It is not implemented on gpu yet Slower but no dependency
|
114 |
+
if self.vals.device.type == 'cpu':
|
115 |
+
tensor_zero_hack = torch.FloatTensor([0]).double() # This line was somehow responsible for a nasty NAN bug
|
116 |
+
else:
|
117 |
+
tensor_zero_hack = torch.cuda.FloatTensor([0]).to(dense.get_device()).double()
|
118 |
+
# beware with addmm, it is experimental and gave me a NaN bug!
|
119 |
+
res = torch.sparse.addmm(tensor_zero_hack, torch.sparse_coo_tensor(self.inds.double(),self.vals.double()) , (dense.double())).type_as(self.vals)
|
120 |
+
# 1000 for loop on the above line takes 0.77 sec. Slower but no dependency
|
121 |
+
return res.contiguous()
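For orientation, here is a minimal usage sketch of this wrapper (illustrative only; the random matrix and right-hand operand are made up for the example, and it assumes the torch_sparse path above is enabled):

    import numpy as np
    import torch
    from scipy.sparse import random as sparse_random

    coo = sparse_random(100, 100, density=0.05, format='coo')   # toy sparse matrix
    A = SparseMat.from_coo(coo, torch.float64)                   # wrap as SparseMat
    x = torch.rand(100, 3, dtype=torch.float64)                  # dense operand
    y = A.multiply_with_dense(x)                                 # equivalent to A @ x, shape (100, 3)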
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
class PoissonSystemMatrices:
|
126 |
+
'''
|
127 |
+
Holds the matrices needed to perform gradient and poisson computations
|
128 |
+
Logic : this class is supposed to hold everything needed to build the Poisson solver
|
129 |
+
Refacto : merge with Poisson Solver
|
130 |
+
Only accept SparseMat representation
|
131 |
+
'''
|
132 |
+
def __init__(self, V, F,grad, rhs, w, ttype, is_sparse = True, lap = None, cpuonly=False):
|
133 |
+
self.dim = 3
|
134 |
+
self.is_sparse = is_sparse
|
135 |
+
self.w = w
|
136 |
+
self.rhs = rhs
|
137 |
+
self.igl_grad = grad
|
138 |
+
self.ttype = ttype
|
139 |
+
self.__splu_L = None
|
140 |
+
self.__splu_U = None
|
141 |
+
self.__splu_perm_c = None
|
142 |
+
self.__splu_perm_r = None
|
143 |
+
self.lap = lap
|
144 |
+
self.__V = V
|
145 |
+
self.__F = F
|
146 |
+
self.cpuonly = cpuonly
|
147 |
+
self.cpu_splu = None
|
148 |
+
|
149 |
+
|
150 |
+
def create_poisson_solver(self):
|
151 |
+
return PoissonSolver(self.igl_grad,self.w,self.rhs, None, self.lap)
|
152 |
+
|
153 |
+
def create_poisson_solver_from_splu_old(self, lap_L, lap_U, lap_perm_c, lap_perm_r):
|
154 |
+
w = torch.from_numpy(self.w).type(self.ttype)
|
155 |
+
lap = None
|
156 |
+
my_splu = None
|
157 |
+
if not self.cpuonly:
|
158 |
+
if USE_CUPY:
|
159 |
+
my_splu = MyCuSPLU(lap_L, lap_U, lap_perm_c, lap_perm_r)
|
160 |
+
else:
|
161 |
+
if self.lap is not None:
|
162 |
+
lap = self.lap
|
163 |
+
# my_splu = scipy_splu(self.lap)
|
164 |
+
# my_splu = MyCuSPLU_CPU(lap_L, lap_U, lap_perm_c, lap_perm_r)
|
165 |
+
else:
|
166 |
+
my_splu = MyCuSPLU_CPU(lap_L, lap_U, lap_perm_c, lap_perm_r)
|
167 |
+
# st = time.time()
|
168 |
+
# my_splu = scipy_splu(lap_L@lap_U)
|
169 |
+
# print(f"time for LU: {time.time() - st}" )
|
170 |
+
|
171 |
+
else:
|
172 |
+
if self.lap is not None:
|
173 |
+
my_splu = scipy_splu(self.lap)
|
174 |
+
else:
|
175 |
+
0/0
|
176 |
+
# my_splu = splu(lap_L)
|
177 |
+
|
178 |
+
return PoissonSolver(self.igl_grad,w,self.rhs,my_splu, lap)
|
179 |
+
|
180 |
+
def compute_poisson_solver_from_laplacian(self, compute_splu=True):
|
181 |
+
self.compute_laplacian()
|
182 |
+
if compute_splu:
|
183 |
+
self.compute_splu()
|
184 |
+
return self.create_poisson_solver_from_splu(self.__splu_L,self.__splu_U,self.__splu_perm_c,self.__splu_perm_r)
|
185 |
+
|
186 |
+
def compute_laplacian(self):
|
187 |
+
if self.lap is None:
|
188 |
+
self.lap = igl.cotmatrix(self.__V,self.__F)
|
189 |
+
self.lap = self.lap[1:, 1:]
|
190 |
+
self.lap = SparseMat.from_coo(self.lap.tocoo(), torch.float64)
|
191 |
+
|
192 |
+
if isinstance(self.lap,PoissonSystemMatrices) and self.lap.vals.shape[0] == self.__V.shape[0]:
|
193 |
+
assert(False), "this should not happen, the fix is to remove a column and row of the laplacian"
|
194 |
+
self.lap = self.lap[1:, 1:]
|
195 |
+
|
196 |
+
return self.lap
|
197 |
+
|
198 |
+
def compute_splu(self):
|
199 |
+
print("i am computing splu")
|
200 |
+
if self.cpu_splu is None:
|
201 |
+
# st = time.time()
|
202 |
+
s = scipy_splu(self.lap)
|
203 |
+
# print(f"time to compute LU {time.time() - st}")
|
204 |
+
# We are storing these attributes just in case we need to create a PoissonSolver on the GPU, they are useless for CPU case.
|
205 |
+
self.cpu_splu = s
|
206 |
+
self.__splu_L = s.L
|
207 |
+
self.__splu_U = s.U
|
208 |
+
self.__splu_perm_c = s.perm_c
|
209 |
+
self.__splu_perm_r = s.perm_r
|
210 |
+
return self.__splu_L,self.__splu_U,self.__splu_perm_c,self.__splu_perm_r
|
211 |
+
|
212 |
+
def get_new_grad(self):
|
213 |
+
grad = self.igl_grad.to_coo()
|
214 |
+
self.igl_grad = SparseMat.from_M(_convert_sparse_igl_grad_to_our_convention(grad.tocsc()),torch.float64)
|
215 |
+
return self.igl_grad
|
216 |
+
|
217 |
+
def _convert_sparse_igl_grad_to_our_convention(input):
|
218 |
+
'''
|
219 |
+
The grad operator computed from igl.grad() results in a matrix of shape (3*#tri x #verts).
|
220 |
+
It is packed such that all the x-coordinates are placed first, followed by y and z. As shown below
|
221 |
+
|
222 |
+
---------- ----------
|
223 |
+
| x1 ... | x1 ...
|
224 |
+
| x2 ... | y1 ...
|
225 |
+
| x3 ... | z1 ...
|
226 |
+
| . | .
|
227 |
+
| . | .
|
228 |
+
| y1 ... | x2 ...
|
229 |
+
| y2 ... ----> | y2 ...
|
230 |
+
| y3 ... | z2 ...
|
231 |
+
| . | .
|
232 |
+
| . | .
|
233 |
+
| z1 ... | x3 ...
|
234 |
+
| z2 ... | y3 ...
|
235 |
+
| z3 ... | z3 ...
|
236 |
+
| . | .
|
237 |
+
| . | .
|
238 |
+
---------- ----------
|
239 |
+
|
240 |
+
Note that this conversion cannot be done trivially because igl.grad() returns a sparse matrix, and as such
|
241 |
+
slicing is not well defined for sparse matrices. The following code performs the above conversion and returns a
|
242 |
+
torch.sparse tensor.
|
243 |
+
Set check to True to verify the results by converting the matrices to dense and comparing it.
|
244 |
+
'''
|
245 |
+
assert type(input) == sp_csc, 'Input should be a scipy csc sparse matrix'
|
246 |
+
T = input.tocoo()
|
247 |
+
|
248 |
+
r_c_data = np.hstack((T.row[..., np.newaxis], T.col[..., np.newaxis],
|
249 |
+
T.data[..., np.newaxis])) # horizontally stack row, col and data arrays
|
250 |
+
r_c_data = r_c_data[r_c_data[:, 0].argsort()] # sort along the row column
|
251 |
+
|
252 |
+
# Separate out x, y and z blocks
|
253 |
+
'''
|
254 |
+
Note that for the grad operator there are exactly 3 non zero elements in a row
|
255 |
+
'''
|
256 |
+
L = T.shape[0]
|
257 |
+
Tx = r_c_data[:L, :]
|
258 |
+
Ty = r_c_data[L:2 * L, :]
|
259 |
+
Tz = r_c_data[2 * L:3 * L, :]
|
260 |
+
|
261 |
+
# align the y,z rows with x so that they too start from 0
|
262 |
+
Ty[:, 0] -= Ty[0, 0]
|
263 |
+
Tz[:, 0] -= Tz[0, 0]
|
264 |
+
|
265 |
+
# 'stretch' the x,y,z rows so that they can be interleaved.
|
266 |
+
Tx[:, 0] *= 3
|
267 |
+
Ty[:, 0] *= 3
|
268 |
+
Tz[:, 0] *= 3
|
269 |
+
|
270 |
+
# interleave the y,z into x
|
271 |
+
Ty[:, 0] += 1
|
272 |
+
Tz[:, 0] += 2
|
273 |
+
|
274 |
+
Tc = np.zeros((input.shape[0] * 3, 3))
|
275 |
+
Tc[::3] = Tx
|
276 |
+
Tc[1::3] = Ty
|
277 |
+
Tc[2::3] = Tz
|
278 |
+
|
279 |
+
indices = Tc[:, :-1].astype(int)
|
280 |
+
data = Tc[:, -1]
|
281 |
+
|
282 |
+
return (indices.T, data, input.shape[0], input.shape[1])
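The reordering described in the docstring above can be sanity-checked numerically. A hedged sketch (not part of the original code; V and F stand for any small test mesh already loaded as numpy arrays):

    import igl
    import numpy as np

    g = igl.grad(V, F).tocsc()                       # (3*#tri, #verts): x rows first, then y, then z
    inds, vals, n, m = _convert_sparse_igl_grad_to_our_convention(g)
    dense_new = np.zeros((n, m))
    dense_new[inds[0], inds[1]] = vals
    dense_old = g.toarray()
    t = g.shape[0] // 3                              # number of triangles
    # after conversion, the rows of triangle 0 are interleaved as x, y, z
    assert np.allclose(dense_new[0], dense_old[0])
    assert np.allclose(dense_new[1], dense_old[t])
    assert np.allclose(dense_new[2], dense_old[2 * t])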
|
283 |
+
|
284 |
+
|
285 |
+
class PoissonSolver:
|
286 |
+
'''
|
287 |
+
an object to compute gradients and solve the Poisson system
|
288 |
+
'''
|
289 |
+
|
290 |
+
def __init__(self,grad,W,rhs,my_splu, lap=None):
|
291 |
+
self.W = torch.from_numpy(W).double()
|
292 |
+
self.grad = grad
|
293 |
+
self.rhs = rhs
|
294 |
+
self.my_splu = my_splu
|
295 |
+
self.lap = lap
|
296 |
+
self.sparse_grad = grad
|
297 |
+
self.sparse_rhs = rhs
|
298 |
+
|
299 |
+
def to(self,device):
|
300 |
+
self.W = self.W.to(device)
|
301 |
+
self.sparse_grad = self.sparse_grad.to(device)
|
302 |
+
self.sparse_rhs = self.sparse_rhs.to(device)
|
303 |
+
if USE_CUPY or USE_CHOLESPY_GPU:
|
304 |
+
self.lap = self.lap.to(device)
|
305 |
+
return self
|
306 |
+
|
307 |
+
def jacobians_from_vertices(self,V):
|
308 |
+
res = _multiply_sparse_2d_by_dense_3d(self.sparse_grad, V).type_as(V)
|
309 |
+
res = res.unsqueeze(2)
|
310 |
+
return res.view(V.shape[0], -1, 3,3).transpose(2,3)
|
311 |
+
|
312 |
+
def restrict_jacobians(self,D):
|
313 |
+
assert isinstance(D, torch.Tensor) and len(D.shape) in [3, 4]
|
314 |
+
assert D.shape[-1] == 3 and D.shape[-2] == 3
|
315 |
+
assert isinstance(self.W, torch.Tensor) and len(self.W.shape) == 3
|
316 |
+
assert self.W.shape[-1] == 2 and self.W.shape[-2] == 3
|
317 |
+
|
318 |
+
if len(D.shape) == 4:
|
319 |
+
DW = torch.einsum("abcd,bde->abce", (D, self.W.type_as(D)))
|
320 |
+
else:
|
321 |
+
DW = torch.einsum("abcd,bde->abce", (D.unsqueeze(0), self.W)).squeeze(0)
|
322 |
+
|
323 |
+
if len(DW.shape)>4:
|
324 |
+
DW = DW.squeeze(0)
|
325 |
+
return DW
|
326 |
+
|
327 |
+
def restricted_jacobians_from_vertices(self,V):
|
328 |
+
return self.restrict_jacobians(self.jacobians_from_vertices(V))
|
329 |
+
|
330 |
+
def solve_poisson(self,jacobians):
|
331 |
+
# st = time.time()
|
332 |
+
assert(len(jacobians.shape) == 4)
|
333 |
+
assert(jacobians.shape[2] == 3 and jacobians.shape[3] == 3)
|
334 |
+
|
335 |
+
# torch.cuda.synchronize()
|
336 |
+
# st = time.time()
|
337 |
+
|
338 |
+
if self.my_splu is None:
|
339 |
+
if isinstance(self.lap,SparseMat):
|
340 |
+
# self.my_splu = scipy_splu(self.lap.to('cpu').to_coo())
|
341 |
+
if USE_CHOLESPY_CPU or USE_CHOLESPY_GPU:
|
342 |
+
self.my_splu = self.lap.to_cholesky()
|
343 |
+
else:
|
344 |
+
self.my_splu = scipy_splu(self.lap.to('cpu').to_coo())
|
345 |
+
else:
|
346 |
+
self.my_splu = scipy_splu(self.lap)
|
347 |
+
|
348 |
+
# print(f"computing poisson! {self.lap.vals.get_device()}")
|
349 |
+
# print(f"computing poisson! {self.lap.inds.get_device()}")
|
350 |
+
# print(f"computing poisson! {jacobians.get_device()}")
|
351 |
+
# print(f"computing poisson! {self.sparse_rhs.vals.get_device()}")
|
352 |
+
# torch.cuda.synchronize()
|
353 |
+
# print(f"SOLVER decomposition {time.time() - st}")
|
354 |
+
|
355 |
+
sol = _predicted_jacobians_to_vertices_via_poisson_solve(self.my_splu, self.sparse_rhs, jacobians.transpose(2, 3).reshape(jacobians.shape[0], -1, 3, 1).squeeze(3).contiguous())
|
356 |
+
# torch.cuda.synchronize()
|
357 |
+
# print(f"POISSON LU + SOLVE FORWARD{time.time() - st}")
|
358 |
+
c = torch.mean(sol, axis=1).unsqueeze(1) ## Beware the predicted mesh is centered here.
|
359 |
+
# print(f"time for poisson: {time.time() - st}" )
|
360 |
+
return sol - c
|
361 |
+
|
362 |
+
def pin_memory(self):
|
363 |
+
return
|
364 |
+
# self.W.pin_memory()
|
365 |
+
# self.sparse_grad.pin_memory()
|
366 |
+
# self.sparse_rhs.pin_memory()
|
367 |
+
|
368 |
+
|
369 |
+
def poisson_system_matrices_from_mesh( V,F, dim=3,ttype = torch.float64, is_sparse=True,cpuonly=False):
|
370 |
+
'''
|
371 |
+
compute the Poisson system matrices for a given mesh
|
372 |
+
:param V vertices
|
373 |
+
:param F faces
|
374 |
+
:param dim: for now always 3 :)
|
375 |
+
:param ttype the type of tensor (e.g., float,double)
|
376 |
+
:param is_sparse: for now always true
|
377 |
+
:return: a PoissonSystemMatrices object holding the computed matrices
|
378 |
+
'''
|
379 |
+
|
380 |
+
assert type(dim) == int and dim in [2,3], f'Only two and three dimensional meshes are supported'
|
381 |
+
assert type(is_sparse) == bool
|
382 |
+
vertices = V
|
383 |
+
faces = F
|
384 |
+
dim = 3
|
385 |
+
is_sparse = is_sparse
|
386 |
+
|
387 |
+
grad = igl.grad(vertices, faces)
|
388 |
+
# grad = np.abs(grad)
|
389 |
+
# temp_grad = grad.multiply(csr_matrix(1 / np.sqrt(grad.multiply(grad).sum(1))))
|
390 |
+
# gradients_normalized = grad / np.linalg.norm(grad, axis=1)[:, np.newaxis]
|
391 |
+
|
392 |
+
mass = _get_mass_matrix(vertices,faces,is_sparse)
|
393 |
+
## TODO 2D Case ##
|
394 |
+
if dim == 2:
|
395 |
+
grad = grad[:-grad.shape[0]//3,:]
|
396 |
+
mass = mass[:-mass.shape[0]//3,:-mass.shape[0]//3]
|
397 |
+
|
398 |
+
laplace = grad.T@mass@grad
|
399 |
+
laplace = laplace[1:, 1:]
|
400 |
+
|
401 |
+
rhs = grad.T@mass
|
402 |
+
b1,b2,_ = igl.local_basis(V,F)
|
403 |
+
w = np.stack((b1,b2),axis=-1)
|
404 |
+
# print(time.time() - s)
|
405 |
+
|
406 |
+
rhs = rhs[1:,:]
|
407 |
+
|
408 |
+
if is_sparse:
|
409 |
+
laplace = laplace.tocoo()
|
410 |
+
rhs = rhs.tocoo()
|
411 |
+
grad = grad.tocsc()
|
412 |
+
else:
|
413 |
+
laplace = laplace.toarray()
|
414 |
+
rhs = rhs.toarray()
|
415 |
+
grad = grad.toarray()
|
416 |
+
|
417 |
+
|
418 |
+
grad = SparseMat.from_M(_convert_sparse_igl_grad_to_our_convention(grad), torch.float64)
|
419 |
+
poissonbuilder = PoissonSystemMatrices(V=V,F=F,grad=grad,
|
420 |
+
rhs=SparseMat.from_coo(rhs, torch.float64), w=w,
|
421 |
+
ttype=ttype,is_sparse=is_sparse,
|
422 |
+
lap=SparseMat.from_coo(laplace, torch.float64),
|
423 |
+
cpuonly=cpuonly)
|
424 |
+
# poissonbuilder.get_new_grad()
|
425 |
+
return poissonbuilder
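A rough end-to-end sketch of how the pieces above fit together (illustrative; V and F are a numpy vertex array and integer face array for some mesh, and the solver backend flags at the top of this file are assumed to be configured):

    import numpy as np
    import torch

    psm = poisson_system_matrices_from_mesh(V, F)          # grad, rhs, Laplacian as SparseMat objects
    solver = psm.create_poisson_solver()                    # PoissonSolver (factorization built lazily)
    verts = torch.from_numpy(V).double().unsqueeze(0)       # (1, #verts, 3)
    J = solver.jacobians_from_vertices(verts)               # (1, #tri, 3, 3) per-face Jacobians
    rec = solver.solve_poisson(J)                           # least-squares reconstruction, centered at the origin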
|
426 |
+
|
427 |
+
def _get_mass_matrix(vertices,faces,is_sparse):
|
428 |
+
|
429 |
+
d_area = igl.doublearea(vertices,faces)
|
430 |
+
d_area = np.hstack((d_area, d_area, d_area))
|
431 |
+
if is_sparse:
|
432 |
+
return sp_csc(diags(d_area))
|
433 |
+
return diags(d_area)
|
434 |
+
|
435 |
+
|
436 |
+
|
437 |
+
|
438 |
+
class SPLUSolveLayer(torch.autograd.Function):
|
439 |
+
'''
|
440 |
+
Implements the SPLU solve as a differentiable layer, with a forward and backward function
|
441 |
+
'''
|
442 |
+
|
443 |
+
@staticmethod
|
444 |
+
def forward(ctx, solver, b):
|
445 |
+
'''
|
446 |
+
override forward function
|
447 |
+
:param ctx: context object (to keep the lu object for the backward pass)
|
448 |
+
:param lu: splu object
|
449 |
+
:param b: right hand side, could be a vector or matrix
|
450 |
+
:return: the vector or matrix x which holds lu.solve(b) = x
|
451 |
+
'''
|
452 |
+
assert isinstance(b, torch.Tensor)
|
453 |
+
assert b.shape[-1] >= 1 and b.shape[-1] <= 3, f'got shape {b.shape} expected last dim to be in range 1-3'
|
454 |
+
b = b.contiguous()
|
455 |
+
ctx.solver = solver
|
456 |
+
|
457 |
+
# st = time.time()
|
458 |
+
vertices = SPLUSolveLayer.solve(solver, b).type_as(b)
|
459 |
+
# print(f"FORWARD SOLVE {time.time() - st}")
|
460 |
+
|
461 |
+
assert not torch.isnan(vertices).any(), "Nan in the forward pass of the POISSON SOLVE"
|
462 |
+
return vertices
|
463 |
+
|
464 |
+
def backward(ctx, grad_output):
|
465 |
+
'''
|
466 |
+
overrides backward function
|
467 |
+
:param grad_output: the gradient to be back-propped
|
468 |
+
:return: the outgoing gradient to be back-propped
|
469 |
+
'''
|
470 |
+
|
471 |
+
assert isinstance(grad_output, torch.Tensor)
|
472 |
+
assert grad_output.shape[-1] >= 1 and grad_output.shape[
|
473 |
+
-1] <= 3, f'got shape {grad_output.shape} expected last dim to be in range 1-3'
|
474 |
+
# when backpropping, if a layer is linear with matrix M, x ---> Mx, then the backprop of gradient g is M^Tg
|
475 |
+
# in our case M = A^{-1}, so the backprop is to solve x = A^-T g.
|
476 |
+
# Because A is symmetric we simply solve A^{-1}g without transposing, but this will break if A is not symmetric.
|
477 |
+
# st = time.time()
|
478 |
+
grad_output = grad_output.contiguous()
|
479 |
+
grad = SPLUSolveLayer.solve(ctx.solver,
|
480 |
+
grad_output)
|
481 |
+
# print(f"BACKWARD SOLVE {time.time() - st}")
|
482 |
+
# At this point we perform a NAN check because the backsolve sometimes returns NaNs.
|
483 |
+
assert not torch.isnan(grad).any(), "Nan in the backward pass of the POISSON SOLVE"
|
484 |
+
|
485 |
+
if USE_CUPY:
|
486 |
+
mempool = cupy.get_default_memory_pool()
|
487 |
+
pinned_mempool = cupy.get_default_pinned_memory_pool()
|
488 |
+
mempool.free_all_blocks()
|
489 |
+
pinned_mempool.free_all_blocks()
|
490 |
+
del ctx.lu
|
491 |
+
|
492 |
+
return None, grad
|
493 |
+
|
494 |
+
@staticmethod
|
495 |
+
def solve(solver, b):
|
496 |
+
'''
|
497 |
+
Solve the linear system defined by an SPLU object for a given right-hand side. If the RHS is a matrix, the solution will also be a matrix.
|
498 |
+
:param solver: the splu object (LU decomposition) or cholesky object
|
499 |
+
:param b: the right hand side to solve for
|
500 |
+
:return: solution x which satisfies Ax = b, where A is the Poisson system the solver describes
|
501 |
+
'''
|
502 |
+
|
503 |
+
if USE_CUPY:
|
504 |
+
b_cupy = cupy.fromDlpack(to_dlpack(b))
|
505 |
+
with cupy.cuda.Device(solver.device()):
|
506 |
+
# this will hold the solution
|
507 |
+
sol = cupy.ndarray(b_cupy.shape)
|
508 |
+
for i in range(b_cupy.shape[2]): # b may have multiple columns, solve for each one
|
509 |
+
b2d = b_cupy[..., i] # cupy.expand_dims(b_cpu[...,i],2)
|
510 |
+
s = solver.solve(b2d.T).T
|
511 |
+
sol[:, :, i] = s
|
512 |
+
# # # convert back to torch
|
513 |
+
res = from_dlpack(sol.toDlpack())
|
514 |
+
# np.save("res_gpu.npy", res.cpu().numpy())
|
515 |
+
# res = torch.zeros((1, 6889, 3), device=b.device)+ torch.mean(b)
|
516 |
+
|
517 |
+
return res.type_as(b.type())
|
518 |
+
|
519 |
+
elif USE_SCIPY:
|
520 |
+
#only CPU
|
521 |
+
# st = time.time()
|
522 |
+
assert(b.shape[0]==1), "Need to code parrallel implem on the first dim"
|
523 |
+
sol = solver.solve(b[0].double().cpu().numpy())
|
524 |
+
res = torch.from_numpy(sol).to(b.device).reshape(b.shape)
|
525 |
+
# print(time.time() - st)
|
526 |
+
return res.type_as(b).contiguous()
|
527 |
+
|
528 |
+
# Legacy code, I don't understand what is the reason for having a for loop
|
529 |
+
# sol = np.ndarray(b.shape)
|
530 |
+
# for i in range(b.shape[2]): # b may have multiple columns, solve for each one
|
531 |
+
# b2d = b[..., i] # cupy.expand_dims(b_cpu[...,i],2)
|
532 |
+
# s = lu.solve(b2d.double().cpu().float().numpy().T).T
|
533 |
+
# sol[:, :, i] = s
|
534 |
+
# res = torch.from_numpy(sol).to(b.device)
|
535 |
+
# # np.save("res_cpu.npy", sol)
|
536 |
+
# print(f"time {time.time() - st}" )
|
537 |
+
elif USE_CHOLESPY_GPU:
|
538 |
+
# torch.cuda.synchronize()
|
539 |
+
# # st = time.time()
|
540 |
+
# assert(b.shape[0]==1), "Need to code parrallel implem on the first dim"
|
541 |
+
# b = b.squeeze().double()
|
542 |
+
# x = torch.zeros_like(b)
|
543 |
+
# solver.solve(b, x)
|
544 |
+
# # torch.cuda.synchronize()
|
545 |
+
# # print(f"time cholescky GPU {time.time() - st}" )
|
546 |
+
# return x.contiguous().unsqueeze(0)
|
547 |
+
# st = time.time()
|
548 |
+
# print(b.get_device(), b.shape)
|
549 |
+
b = b.double().contiguous()
|
550 |
+
c = b.permute(1,2,0).contiguous()
|
551 |
+
c = c.view(c.shape[0], -1)
|
552 |
+
x = torch.zeros_like(c)
|
553 |
+
solver.solve(c, x)
|
554 |
+
x = x.view(b.shape[1], b.shape[2], b.shape[0])
|
555 |
+
x = x.permute(2,0,1).contiguous()
|
556 |
+
# torch.cuda.synchronize()
|
557 |
+
# print(f"time cholescky GPU {time.time() - st}" )
|
558 |
+
return x.contiguous()
|
559 |
+
|
560 |
+
elif USE_CHOLESPY_CPU:
|
561 |
+
# st = time.time()
|
562 |
+
assert(b.shape[0]==1), "Need to code parallel implementation on the first dim"
|
563 |
+
b = b.squeeze()
|
564 |
+
b_cpu = b.cpu()
|
565 |
+
x = torch.zeros_like(b_cpu)
|
566 |
+
solver.solve(b_cpu, x)
|
567 |
+
# print(f"time cholescky CPU {time.time() - st}" )
|
568 |
+
return x.contiguous().to(b.device).unsqueeze(0)
|
569 |
+
|
570 |
+
|
571 |
+
return res.type_as(b)
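The comment in SPLUSolveLayer.backward above relies on a standard fact: for the linear layer x = A^{-1} b, the gradient with respect to b is A^{-T} g, which equals A^{-1} g when A is symmetric. A standalone sketch of that identity with a dense solve (independent of the classes here, just to make the reasoning concrete):

    import torch

    A = torch.rand(5, 5, dtype=torch.float64)
    A = A @ A.T + 5 * torch.eye(5, dtype=torch.float64)   # symmetric positive definite
    b = torch.rand(5, 3, dtype=torch.float64, requires_grad=True)
    x = torch.linalg.solve(A, b)
    g = torch.rand_like(x)
    x.backward(g)
    # for symmetric A:  dL/db = A^{-T} g = A^{-1} g
    assert torch.allclose(b.grad, torch.linalg.solve(A, g))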
|
572 |
+
|
573 |
+
def _predicted_jacobians_to_vertices_via_poisson_solve(Lap, rhs, jacobians):
|
574 |
+
'''
|
575 |
+
convert the predictions to the correct convention and feed it to the poisson solve
|
576 |
+
'''
|
577 |
+
|
578 |
+
def _batch_rearrange_input(input):
|
579 |
+
assert isinstance(input, torch.Tensor) and len(input.shape) in [2, 3]
|
580 |
+
P = torch.zeros(input.shape).type_as(input)
|
581 |
+
if len(input.shape) == 3:
|
582 |
+
# Batched input
|
583 |
+
k = input.shape[1] // 3
|
584 |
+
P[:, :k, :] = input[:, ::3]
|
585 |
+
P[:, k:2 * k, :] = input[:, 1::3]
|
586 |
+
P[:, 2 * k:, :] = input[:, 2::3]
|
587 |
+
|
588 |
+
else:
|
589 |
+
k = input.shape[0] // 3
|
590 |
+
P[:k, :] = input[::3]
|
591 |
+
P[k:2 * k, :] = input[1::3]
|
592 |
+
P[2 * k:, :] = input[2::3]
|
593 |
+
|
594 |
+
return P
|
595 |
+
|
596 |
+
def _list_rearrange_input(input):
|
597 |
+
assert isinstance(input, list) and all([isinstance(x, torch.Tensor) and len(x.shape) in [2, 3] for x in input])
|
598 |
+
P = []
|
599 |
+
for p in input:
|
600 |
+
P.append(_batch_rearrange_input(p))
|
601 |
+
return P
|
602 |
+
|
603 |
+
if isinstance(jacobians, list):
|
604 |
+
P = _list_rearrange_input(jacobians)
|
605 |
+
else:
|
606 |
+
P = _batch_rearrange_input(jacobians)
|
607 |
+
|
608 |
+
# return solve_poisson(Lap, rhs, P)
|
609 |
+
assert isinstance(P, torch.Tensor) and len(P.shape) in [2, 3]
|
610 |
+
assert len(P.shape) == 3
|
611 |
+
|
612 |
+
# torch.cuda.synchronize()
|
613 |
+
# st = time.time()
|
614 |
+
P = P.double()
|
615 |
+
input_to_solve = _multiply_sparse_2d_by_dense_3d(rhs, P)
|
616 |
+
|
617 |
+
|
618 |
+
out = SPLUSolveLayer.apply(Lap, input_to_solve)
|
619 |
+
|
620 |
+
out = torch.cat([torch.zeros(out.shape[0], 1, out.shape[2]).type_as(out), out], dim=1)  # re-insert the pinned first vertex: its row/column was dropped from the Laplacian and rhs, so the solve only returns vertices 1..n-1
|
621 |
+
out = out - torch.mean(out, axis=1, keepdim=True)
|
622 |
+
|
623 |
+
return out.type_as(jacobians)
|
624 |
+
|
625 |
+
|
626 |
+
|
627 |
+
def _multiply_sparse_2d_by_dense_3d(mat, B):
|
628 |
+
ret = []
|
629 |
+
for i in range(B.shape[0]):
|
630 |
+
C = mat.multiply_with_dense(B[i, ...])
|
631 |
+
ret.append(C)
|
632 |
+
ret = torch.stack(tuple(ret))
|
633 |
+
return ret
|
634 |
+
|
635 |
+
|
636 |
+
|
637 |
+
|
638 |
+
|
639 |
+
|
640 |
+
|
641 |
+
class MyCuSPLU:
|
642 |
+
'''
|
643 |
+
implementation of SPLU on the GPU via CuPy
|
644 |
+
'''
|
645 |
+
def __init__(self, L, U, perm_c=None, perm_r=None):
|
646 |
+
# with cupy.cuda.Device(device):
|
647 |
+
self.__orgL = L
|
648 |
+
self.__orgU = U
|
649 |
+
# self.L = csr_matrix(L)
|
650 |
+
# self.U = csr_matrix(U)
|
651 |
+
self.L = None
|
652 |
+
self.U = None
|
653 |
+
self.perm_c = perm_c
|
654 |
+
self.perm_r = perm_r
|
655 |
+
# self.splu = cu_splu(csr_matrix(lap))
|
656 |
+
# self.L = self.splu.L
|
657 |
+
# self.U = self.splu.U
|
658 |
+
# self.perm_c = self.splu.perm_c
|
659 |
+
# self.perm_r = self.splu.perm_r
|
660 |
+
self.__device = None
|
661 |
+
|
662 |
+
def to(self, device):
|
663 |
+
# assumes to receive a pytorch device object that has a "index" field
|
664 |
+
# print(device)
|
665 |
+
# if(self.__device is None):
|
666 |
+
# raise Exception()
|
667 |
+
self.__device = device.index
|
668 |
+
with cupy.cuda.Device(self.__device):
|
669 |
+
# self.__orgL = cupy.asarray(self.__orgL)
|
670 |
+
# self.__orgU = cupy.asarray(self.__orgU)
|
671 |
+
self.L = csr_matrix(self.__orgL)
|
672 |
+
self.U = csr_matrix(self.__orgU)
|
673 |
+
return self
|
674 |
+
|
675 |
+
def device(self):
|
676 |
+
return self.__device
|
677 |
+
|
678 |
+
def solve(self, b):
|
679 |
+
""" an attempt to use SuperLU data to efficiently solve
|
680 |
+
Ax = Pr.T L U Pc.T x = b
|
681 |
+
- note that L from SuperLU is in CSC format solving for c
|
682 |
+
results in an efficiency warning
|
683 |
+
Pr . A . Pc = L . U
|
684 |
+
Lc = b - forward solve for c
|
685 |
+
c = Ux - then back solve for x
|
686 |
+
"""
|
687 |
+
|
688 |
+
assert self.__device is not None, "need to explicitly call to() before solving"
|
689 |
+
if USE_UGLY_PATCH_FOR_CUPY_ERROR:
|
690 |
+
with cupy.cuda.Device(0):
|
691 |
+
b[:1, :1].copy()[:, :1]
|
692 |
+
|
693 |
+
with cupy.cuda.Device(self.__device):
|
694 |
+
b = cupy.array(b)
|
695 |
+
if self.perm_r is not None:
|
696 |
+
b_old = b.copy()
|
697 |
+
b[self.perm_r] = b_old
|
698 |
+
|
699 |
+
assert b.device.id == self.__device, "got device" + str(b.device.id) + "instead of" + str(self.__device)
|
700 |
+
# st = time.time()
|
701 |
+
try: # unit_diagonal is a new kw
|
702 |
+
c = spsolve_triangular(self.L, b, lower=True, unit_diagonal=True, overwrite_b=True)
|
703 |
+
except TypeError:
|
704 |
+
c = spsolve_triangular(self.L, b, lower=True, overwrite_b=True)
|
705 |
+
px = spsolve_triangular(self.U, c, lower=False, overwrite_b=True)
|
706 |
+
# print(f"time for spsolve_triangular GPU: {time.time() - st}" )
|
707 |
+
|
708 |
+
if self.perm_c is None:
|
709 |
+
return px
|
710 |
+
px = px[self.perm_c]
|
711 |
+
|
712 |
+
# print(f'used: {mempool.used_bytes()}')
|
713 |
+
# print(f'total: {mempool.total_bytes()}')
|
714 |
+
return px
|
715 |
+
|
716 |
+
|
717 |
+
class MyCuSPLU_CPU:
|
718 |
+
'''
|
719 |
+
CPU counterpart of MyCuSPLU: the same SuperLU L/U solve performed with SciPy on the CPU
|
720 |
+
'''
|
721 |
+
def __init__(self, L, U, perm_c=None, perm_r=None):
|
722 |
+
# with cupy.cuda.Device(device):
|
723 |
+
self.__orgL = L
|
724 |
+
self.__orgU = U
|
725 |
+
# self.L = csr_matrix(L)
|
726 |
+
# self.U = csr_matrix(U)
|
727 |
+
self.L = L
|
728 |
+
self.U = U
|
729 |
+
# self.L = L.tocsr()
|
730 |
+
# self.U = U.tocsr()
|
731 |
+
self.perm_c = perm_c
|
732 |
+
self.perm_r = perm_r
|
733 |
+
# self.splu = cu_splu(csr_matrix(lap))
|
734 |
+
# self.L = self.splu.L
|
735 |
+
# self.U = self.splu.U
|
736 |
+
# self.perm_c = self.splu.perm_c
|
737 |
+
# self.perm_r = self.splu.perm_r
|
738 |
+
self.__device = 'cpu'
|
739 |
+
|
740 |
+
def to(self, device):
|
741 |
+
# assumes to receive a pytorch device object that has a "index" field
|
742 |
+
# print(device)
|
743 |
+
# if(self.__device is None):
|
744 |
+
# raise Exception()
|
745 |
+
# self.__device = device.index
|
746 |
+
# with cupy.cuda.Device(self.__device):
|
747 |
+
# # self.__orgL = cupy.asarray(self.__orgL)
|
748 |
+
# # self.__orgU = cupy.asarray(self.__orgU)
|
749 |
+
# self.L = csr_matrix(self.__orgL)
|
750 |
+
# self.U = csr_matrix(self.__orgU)
|
751 |
+
return self
|
752 |
+
|
753 |
+
def device(self):
|
754 |
+
return self.__device
|
755 |
+
|
756 |
+
def solve(self, b):
|
757 |
+
""" an attempt to use SuperLU data to efficiently solve
|
758 |
+
Ax = Pr.T L U Pc.T x = b
|
759 |
+
- note that L from SuperLU is in CSC format solving for c
|
760 |
+
results in an efficiency warning
|
761 |
+
Pr . A . Pc = L . U
|
762 |
+
Lc = b - forward solve for c
|
763 |
+
c = Ux - then back solve for x
|
764 |
+
"""
|
765 |
+
|
766 |
+
|
767 |
+
# Could be done on GPU
|
768 |
+
if self.perm_r is not None:
|
769 |
+
b_old = b.copy()
|
770 |
+
b[self.perm_r] = b_old
|
771 |
+
# , permc_spec="NATURAL"
|
772 |
+
# , permc_spec="NATURAL"
|
773 |
+
# , permc_spec="NATURAL"
|
774 |
+
st = time.time()
|
775 |
+
# try: # unit_diagonal is a new kw
|
776 |
+
# c = spsolve_triangular(self.L, b, lower=True, unit_diagonal=True, overwrite_b=True)
|
777 |
+
# except TypeError:
|
778 |
+
# c = spsolve_triangular(self.L, b, lower=True, overwrite_b=True)
|
779 |
+
# px = spsolve_triangular(self.U, c, lower=False, overwrite_b=True)
|
780 |
+
try: # unit_diagonal is a new kw
|
781 |
+
c = spsolve(self.L, b, permc_spec="NATURAL")
|
782 |
+
except TypeError:
|
783 |
+
c = spsolve(self.L, b, permc_spec="NATURAL")
|
784 |
+
px = spsolve(self.U, c, permc_spec="NATURAL")
|
785 |
+
# # (self.L * c) - b / np.norm(b)
|
786 |
+
print(f"time for spsolve_triangular CPU: {time.time() - st}" )
|
787 |
+
|
788 |
+
if self.perm_c is None:
|
789 |
+
return px
|
790 |
+
px = px[self.perm_c]
|
791 |
+
|
792 |
+
# print(f'used: {mempool.used_bytes()}')
|
793 |
+
# print(f'total: {mempool.total_bytes()}')
|
794 |
+
return px
|
795 |
+
|
796 |
+
# return cupy.asnumpy(px)
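Both solve() implementations above follow SciPy's SuperLU convention (A = Pr.T @ L @ U @ Pc.T, with Pr and Pc built from perm_r and perm_c). A small self-contained check of that row/column bookkeeping against scipy (illustrative only; the random matrix is made diagonally dominant so the factorization is well behaved):

    import numpy as np
    from scipy.sparse import identity, random as sparse_random
    from scipy.sparse.linalg import splu, spsolve_triangular

    A = (sparse_random(50, 50, density=0.1) + 10 * identity(50)).tocsc()
    lu = splu(A)
    b = np.random.rand(50)

    pb = b.copy()
    pb[lu.perm_r] = b                                   # apply the row permutation Pr
    c = spsolve_triangular(lu.L.tocsr(), pb, lower=True)
    x = spsolve_triangular(lu.U.tocsr(), c, lower=False)
    x = x[lu.perm_c]                                    # apply the column permutation Pc
    assert np.allclose(A @ x, b)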
|
NeuralJacobianFields/SourceMesh.py
ADDED
@@ -0,0 +1,152 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
import torch
|
5 |
+
import igl
|
6 |
+
from . import MeshProcessor
|
7 |
+
WKS_DIM = MeshProcessor.WKS_DIM
|
8 |
+
WKS_FACTOR = 1000
|
9 |
+
import numpy as np
|
10 |
+
import sys
|
11 |
+
import random
|
12 |
+
import time
|
13 |
+
class SourceMesh:
|
14 |
+
'''
|
15 |
+
data structure for the source mesh to be mapped
|
16 |
+
'''
|
17 |
+
|
18 |
+
def __init__(self, source_ind, source_dir, extra_source_fields,
|
19 |
+
random_scale, ttype, use_wks=False, random_centering=False,
|
20 |
+
cpuonly=False):
|
21 |
+
self.__use_wks = use_wks
|
22 |
+
self.source_ind = source_ind
|
23 |
+
self.source_dir = source_dir
|
24 |
+
self.centroids_and_normals = None
|
25 |
+
self.center_source = True
|
26 |
+
self.poisson = None
|
27 |
+
self.splu = None
|
28 |
+
self.__source_global_translation_to_original = 0
|
29 |
+
self.__extra_keys = extra_source_fields
|
30 |
+
self.__loaded_data = {}
|
31 |
+
self.__ttype = ttype
|
32 |
+
self.__random_scale = random_scale
|
33 |
+
self.random_centering = random_centering
|
34 |
+
self.source_mesh_centroid = None
|
35 |
+
self.mesh_processor = None
|
36 |
+
self.cpuonly = cpuonly
|
37 |
+
|
38 |
+
def get_vertices(self):
|
39 |
+
return self.source_vertices
|
40 |
+
|
41 |
+
def get_global_translation_to_original(self):
|
42 |
+
return self.__source_global_translation_to_original
|
43 |
+
|
44 |
+
def vertices_from_jacobians(self, d):
|
45 |
+
return self.poisson.solve_poisson(d)
|
46 |
+
# return self.splu.solve(d)
|
47 |
+
def jacobians_from_vertices(self, v):
|
48 |
+
return self.poisson.jacobians_from_vertices(v)
|
49 |
+
|
50 |
+
def restrict_jacobians(self, J):
|
51 |
+
return self.poisson.restrict_jacobians(J)
|
52 |
+
|
53 |
+
def get_loaded_data(self, key: str):
|
54 |
+
|
55 |
+
return self.__loaded_data.get(key)
|
56 |
+
|
57 |
+
def get_source_triangles(self):
|
58 |
+
# if self.__source_triangles is None:
|
59 |
+
# self.__source_triangles = np.load(os.path.join(self.source_dir, 'faces.npy'))
|
60 |
+
return self.mesh_processor.get_faces()
|
61 |
+
|
62 |
+
def to(self, device):
|
63 |
+
self.poisson = self.poisson.to(device)
|
64 |
+
self.splu = self.splu.to(device)
|
65 |
+
self.centroids_and_normals = self.centroids_and_normals.to(device)
|
66 |
+
for key in self.__loaded_data.keys():
|
67 |
+
self.__loaded_data[key] = self.__loaded_data[key].to(device)
|
68 |
+
return self
|
69 |
+
|
70 |
+
def __init_from_mesh_data(self):
|
71 |
+
assert self.mesh_processor is not None
|
72 |
+
self.mesh_processor.prepare_differential_operators_for_use(self.__ttype) #call 1
|
73 |
+
self.source_vertices = torch.from_numpy(self.mesh_processor.get_vertices()).type(
|
74 |
+
self.__ttype)
|
75 |
+
if self.__random_scale != 1:
|
76 |
+
print("Diff ops and WKS need to be multiplied accordingly. Not implemented for now")
|
77 |
+
sys.exit()
|
78 |
+
self.source_vertices *= self.__random_scale
|
79 |
+
|
80 |
+
bb = igl.bounding_box(self.source_vertices.numpy())[0]
|
81 |
+
diag = igl.bounding_box_diagonal(self.source_vertices.numpy())
|
82 |
+
|
83 |
+
# self.source_mesh_centroid = torch.mean(self.source_vertices, axis=0)
|
84 |
+
self.source_mesh_centroid = (bb[0] + bb[-1])/2
|
85 |
+
if self.random_centering:
|
86 |
+
# centering augmentation
|
87 |
+
self.source_mesh_centroid = self.source_mesh_centroid + [(2*random.random() - 1)*diag*0.2, (2*random.random() - 1)*diag*0.2, (2*random.random() - 1)*diag*0.2]
|
88 |
+
# self.source_mesh_centroid = (bb[0] + bb[-1])/2 - np.array([-0.00033245, -0.2910367 , 0.02100835])
|
89 |
+
|
90 |
+
# Load input to NJF MLP
|
91 |
+
# start = time.time()
|
92 |
+
centroids = self.mesh_processor.get_centroids()
|
93 |
+
centroid_points_and_normals = centroids.points_and_normals
|
94 |
+
if self.__use_wks:
|
95 |
+
wks = WKS_FACTOR * centroids.wks
|
96 |
+
centroid_points_and_normals = numpy.hstack((centroid_points_and_normals, wks))
|
97 |
+
self.centroids_and_normals = torch.from_numpy(
|
98 |
+
centroid_points_and_normals).type(self.__ttype)
|
99 |
+
if self.center_source:
|
100 |
+
c = self.source_mesh_centroid
|
101 |
+
self.centroids_and_normals[:, 0:3] -= c
|
102 |
+
self.source_vertices -= c
|
103 |
+
self.__source_global_translation_to_original = c
|
104 |
+
self.poisson = self.mesh_processor.diff_ops.poisson_solver
|
105 |
+
self.splu = self.mesh_processor.diff_ops.MyCuSPLU_solver
|
106 |
+
|
107 |
+
|
108 |
+
# Essentially here we load pointnet data and apply the same preprocessing
|
109 |
+
for key in self.__extra_keys:
|
110 |
+
data = self.mesh_processor.get_data(key)
|
111 |
+
# if data is None: # not found in mesh data so try loading from disk
|
112 |
+
# data = np.load(os.path.join(self.source_dir, key + ".npy"))
|
113 |
+
data = torch.from_numpy(data)
|
114 |
+
if key == 'samples':
|
115 |
+
if self.center_source:
|
116 |
+
data -= self.get_mesh_centroid()
|
117 |
+
scale = self.__random_scale
|
118 |
+
data *= scale
|
119 |
+
data = data.unsqueeze(0).type(self.__ttype)
|
120 |
+
|
121 |
+
self.__loaded_data[key] = data
|
122 |
+
# print("Ellapsed load source mesh ", time.time() - start)
|
123 |
+
|
124 |
+
def load(self, source_v=None, source_f=None):
|
125 |
+
# mesh_data = SourceMeshData.SourceMeshData.meshprocessor_from_file(self.source_dir)
|
126 |
+
if source_v is not None and source_f is not None:
|
127 |
+
self.mesh_processor = MeshProcessor.MeshProcessor.meshprocessor_from_array(source_v,source_f, self.source_dir, self.__ttype, cpuonly=self.cpuonly, load_wks_samples=self.__use_wks, load_wks_centroids=self.__use_wks)
|
128 |
+
else:
|
129 |
+
if os.path.isdir(self.source_dir):
|
130 |
+
self.mesh_processor = MeshProcessor.MeshProcessor.meshprocessor_from_directory(self.source_dir, self.__ttype, cpuonly=self.cpuonly, load_wks_samples=self.__use_wks, load_wks_centroids=self.__use_wks)
|
131 |
+
else:
|
132 |
+
self.mesh_processor = MeshProcessor.MeshProcessor.meshprocessor_from_file(self.source_dir, self.__ttype, cpuonly=self.cpuonly, load_wks_samples=self.__use_wks, load_wks_centroids=self.__use_wks)
|
133 |
+
self.__init_from_mesh_data()
|
134 |
+
|
135 |
+
def get_point_dim(self):
|
136 |
+
return self.centroids_and_normals.shape[1]
|
137 |
+
|
138 |
+
def get_centroids_and_normals(self):
|
139 |
+
return self.centroids_and_normals
|
140 |
+
|
141 |
+
def get_mesh_centroid(self):
|
142 |
+
return self.source_mesh_centroid
|
143 |
+
|
144 |
+
def pin_memory(self):
|
145 |
+
# self.poisson.pin_memory()
|
146 |
+
# self.centroids_and_normals.pin_memory()
|
147 |
+
# self.source_vertices.pin_memory()
|
148 |
+
# for key in self.__loaded_data.keys():
|
149 |
+
# self.__loaded_data[key].pin_memory()
|
150 |
+
return self
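How this class is driven in practice can be seen in loop.py further down; in short (a condensed sketch of that call sequence, with output_path, device, and load_mesh standing in for the values set up there):

    jacobian_source = SourceMesh.SourceMesh(0, str(output_path / 'tmp' / 'mesh.obj'), {}, 1, ttype=torch.float)
    jacobian_source.load()                     # builds / loads the Poisson operators for the mesh
    jacobian_source.to(device)
    with torch.no_grad():
        gt_jacobians = jacobian_source.jacobians_from_vertices(load_mesh.v_pos.unsqueeze(0))
    new_verts = jacobian_source.vertices_from_jacobians(gt_jacobians)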
|
151 |
+
|
152 |
+
|
app.py
ADDED
@@ -0,0 +1,230 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import yaml
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
import gradio as gr
|
8 |
+
from pathlib import Path
|
9 |
+
import tempfile
|
10 |
+
import shutil
|
11 |
+
|
12 |
+
# Add the current directory to Python path
|
13 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
14 |
+
|
15 |
+
# Add packages directory to Python path
|
16 |
+
packages_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'packages')
|
17 |
+
if os.path.exists(packages_dir):
|
18 |
+
sys.path.append(packages_dir)
|
19 |
+
|
20 |
+
try:
|
21 |
+
from loop import loop
|
22 |
+
except ImportError as e:
|
23 |
+
print(f"Error importing loop: {e}")
|
24 |
+
print("Make sure all dependencies are installed correctly")
|
25 |
+
sys.exit(1)
|
26 |
+
|
27 |
+
# Global variables for configuration
|
28 |
+
DEFAULT_CONFIG = {
|
29 |
+
'output_path': './outputs',
|
30 |
+
'gpu': 0,
|
31 |
+
'seed': 99,
|
32 |
+
'clip_model': 'ViT-B/32',
|
33 |
+
'consistency_clip_model': 'ViT-B/32',
|
34 |
+
'consistency_vit_stride': 8,
|
35 |
+
'consistency_vit_layer': 11,
|
36 |
+
'mesh': './meshes/longsleeve.obj',
|
37 |
+
'target_mesh': './meshes_target/jacket_sdf_new.obj',
|
38 |
+
'retriangulate': 0,
|
39 |
+
'bsdf': 'diffuse',
|
40 |
+
'lr': 0.0025,
|
41 |
+
'epochs': 1800,
|
42 |
+
'clip_weight': 2.5,
|
43 |
+
'delta_clip_weight': 5,
|
44 |
+
'vgg_weight': 0.0,
|
45 |
+
'face_weight': 0,
|
46 |
+
'regularize_jacobians_weight': 0.15,
|
47 |
+
'consistency_loss_weight': 0,
|
48 |
+
'consistency_elev_filter': 30,
|
49 |
+
'consistency_azim_filter': 20,
|
50 |
+
'batch_size': 24,
|
51 |
+
'train_res': 512,
|
52 |
+
'resize_method': 'cubic',
|
53 |
+
'fov_min': 30.0,
|
54 |
+
'fov_max': 90.0,
|
55 |
+
'dist_min': 2.5,
|
56 |
+
'dist_max': 3.5,
|
57 |
+
'light_power': 5.0,
|
58 |
+
'elev_alpha': 1.0,
|
59 |
+
'elev_beta': 5.0,
|
60 |
+
'elev_max': 60.0,
|
61 |
+
'azim_min': 0.0,
|
62 |
+
'azim_max': 360.0,
|
63 |
+
'aug_loc': 1,
|
64 |
+
'aug_light': 1,
|
65 |
+
'aug_bkg': 0,
|
66 |
+
'adapt_dist': 1,
|
67 |
+
'log_interval': 5,
|
68 |
+
'log_interval_im': 150,
|
69 |
+
'log_elev': 0,
|
70 |
+
'log_fov': 60.0,
|
71 |
+
'log_dist': 3.0,
|
72 |
+
'log_res': 512,
|
73 |
+
'log_light_power': 3.0
|
74 |
+
}
|
75 |
+
|
76 |
+
def process_garment(text_prompt, base_text_prompt, epochs, learning_rate, clip_weight, delta_clip_weight, progress=gr.Progress()):
|
77 |
+
"""
|
78 |
+
Main function to process garment generation
|
79 |
+
"""
|
80 |
+
try:
|
81 |
+
# Create a temporary output directory
|
82 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
83 |
+
# Update configuration
|
84 |
+
config = DEFAULT_CONFIG.copy()
|
85 |
+
config.update({
|
86 |
+
'output_path': temp_dir,
|
87 |
+
'text_prompt': text_prompt,
|
88 |
+
'base_text_prompt': base_text_prompt,
|
89 |
+
'epochs': int(epochs),
|
90 |
+
'lr': float(learning_rate),
|
91 |
+
'clip_weight': float(clip_weight),
|
92 |
+
'delta_clip_weight': float(delta_clip_weight),
|
93 |
+
'gpu': 0 # Use first GPU
|
94 |
+
})
|
95 |
+
|
96 |
+
# Set random seeds
|
97 |
+
random.seed(config['seed'])
|
98 |
+
os.environ['PYTHONHASHSEED'] = str(config['seed'])
|
99 |
+
np.random.seed(config['seed'])
|
100 |
+
torch.manual_seed(config['seed'])
|
101 |
+
torch.cuda.manual_seed(config['seed'])
|
102 |
+
torch.backends.cudnn.deterministic = True
|
103 |
+
|
104 |
+
progress(0.1, desc="Initializing...")
|
105 |
+
|
106 |
+
# Run the main processing loop
|
107 |
+
loop(config)
|
108 |
+
|
109 |
+
progress(0.9, desc="Processing complete, preparing output...")
|
110 |
+
|
111 |
+
# Look for output files
|
112 |
+
output_files = []
|
113 |
+
for file_path in Path(temp_dir).rglob("*"):
|
114 |
+
if file_path.is_file() and file_path.suffix.lower() in ['.obj', '.png', '.jpg', '.jpeg', '.gif', '.mp4']:
|
115 |
+
output_files.append(str(file_path))
|
116 |
+
|
117 |
+
if output_files:
|
118 |
+
return output_files[0] if len(output_files) == 1 else output_files
|
119 |
+
else:
|
120 |
+
return "Processing completed but no output files found."
|
121 |
+
|
122 |
+
except Exception as e:
|
123 |
+
return f"Error during processing: {str(e)}"
|
124 |
+
|
125 |
+
def create_interface():
|
126 |
+
"""
|
127 |
+
Create the Gradio interface
|
128 |
+
"""
|
129 |
+
with gr.Blocks(title="Garment3DGen - 3D Garment Stylization", theme=gr.themes.Soft()) as interface:
|
130 |
+
gr.Markdown("""
|
131 |
+
# Garment3DGen: 3D Garment Stylization and Texture Generation
|
132 |
+
|
133 |
+
This tool allows you to stylize 3D garments using text prompts. Upload a 3D mesh and describe the desired style to generate a new 3D garment.
|
134 |
+
|
135 |
+
## How to use:
|
136 |
+
1. Enter a text prompt describing the target style (e.g., "leather jacket with studs")
|
137 |
+
2. Enter a base text prompt describing the input mesh (e.g., "simple t-shirt")
|
138 |
+
3. Adjust the parameters as needed
|
139 |
+
4. Click "Generate" to start the process
|
140 |
+
|
141 |
+
**Note:** Processing may take several minutes depending on the number of epochs.
|
142 |
+
""")
|
143 |
+
|
144 |
+
with gr.Row():
|
145 |
+
with gr.Column(scale=1):
|
146 |
+
gr.Markdown("### Input Parameters")
|
147 |
+
|
148 |
+
text_prompt = gr.Textbox(
|
149 |
+
label="Target Text Prompt",
|
150 |
+
placeholder="e.g., leather jacket with studs, denim jacket with patches",
|
151 |
+
value="leather jacket with studs"
|
152 |
+
)
|
153 |
+
|
154 |
+
base_text_prompt = gr.Textbox(
|
155 |
+
label="Base Text Prompt",
|
156 |
+
placeholder="e.g., simple t-shirt, basic long sleeve shirt",
|
157 |
+
value="simple t-shirt"
|
158 |
+
)
|
159 |
+
|
160 |
+
epochs = gr.Slider(
|
161 |
+
minimum=100,
|
162 |
+
maximum=3000,
|
163 |
+
value=1800,
|
164 |
+
step=100,
|
165 |
+
label="Number of Epochs",
|
166 |
+
info="More epochs = better quality but longer processing time"
|
167 |
+
)
|
168 |
+
|
169 |
+
learning_rate = gr.Slider(
|
170 |
+
minimum=0.0001,
|
171 |
+
maximum=0.01,
|
172 |
+
value=0.0025,
|
173 |
+
step=0.0001,
|
174 |
+
label="Learning Rate"
|
175 |
+
)
|
176 |
+
|
177 |
+
clip_weight = gr.Slider(
|
178 |
+
minimum=0.1,
|
179 |
+
maximum=10.0,
|
180 |
+
value=2.5,
|
181 |
+
step=0.1,
|
182 |
+
label="CLIP Weight"
|
183 |
+
)
|
184 |
+
|
185 |
+
delta_clip_weight = gr.Slider(
|
186 |
+
minimum=0.1,
|
187 |
+
maximum=20.0,
|
188 |
+
value=5.0,
|
189 |
+
step=0.1,
|
190 |
+
label="Delta CLIP Weight"
|
191 |
+
)
|
192 |
+
|
193 |
+
generate_btn = gr.Button("Generate 3D Garment", variant="primary")
|
194 |
+
|
195 |
+
with gr.Column(scale=1):
|
196 |
+
gr.Markdown("### Output")
|
197 |
+
output = gr.File(label="Generated 3D Garment")
|
198 |
+
status = gr.Textbox(label="Status", interactive=False)
|
199 |
+
|
200 |
+
# Connect the button to the processing function
|
201 |
+
generate_btn.click(
|
202 |
+
fn=process_garment,
|
203 |
+
inputs=[text_prompt, base_text_prompt, epochs, learning_rate, clip_weight, delta_clip_weight],
|
204 |
+
outputs=[output]
|
205 |
+
)
|
206 |
+
|
207 |
+
gr.Markdown("""
|
208 |
+
## Tips for better results:
|
209 |
+
- Be specific in your text prompts
|
210 |
+
- Use descriptive terms for materials, colors, and styles
|
211 |
+
- The base text prompt should accurately describe your input mesh
|
212 |
+
- Higher epoch counts generally produce better results but take longer
|
213 |
+
- Experiment with different CLIP weights for different effects
|
214 |
+
|
215 |
+
## Technical Details:
|
216 |
+
This tool uses Neural Jacobian Fields and CLIP embeddings to deform and stylize 3D garment meshes.
|
217 |
+
The process involves optimizing the mesh geometry and texture to match the target text description.
|
218 |
+
""")
|
219 |
+
|
220 |
+
return interface
|
221 |
+
|
222 |
+
if __name__ == "__main__":
|
223 |
+
# Create and launch the interface
|
224 |
+
interface = create_interface()
|
225 |
+
interface.launch(
|
226 |
+
server_name="0.0.0.0",
|
227 |
+
server_port=7860,
|
228 |
+
share=False,
|
229 |
+
debug=True
|
230 |
+
)
|
asset_visualization/armor.gif
ADDED
Git LFS Details
|
example_config.yml
ADDED
@@ -0,0 +1,56 @@
1 |
+
output_path: ./outputs
|
2 |
+
gpu: 0
|
3 |
+
seed: 99
|
4 |
+
|
5 |
+
# CLIP-related
|
6 |
+
clip_model: ViT-B/32
|
7 |
+
consistency_clip_model: ViT-B/32
|
8 |
+
consistency_vit_stride: 8
|
9 |
+
consistency_vit_layer: 11
|
10 |
+
|
11 |
+
# Mesh
|
12 |
+
mesh: ./meshes/longsleeve.obj
|
13 |
+
target_mesh: ./meshes_target/jacket_sdf_new.obj
|
14 |
+
retriangulate: 0
|
15 |
+
|
16 |
+
# Render settings
|
17 |
+
bsdf: diffuse
|
18 |
+
|
19 |
+
# Hyper-parameters
|
20 |
+
lr: 0.0025 # 0.0025
|
21 |
+
epochs: 1800
|
22 |
+
clip_weight: 2.5 #0.5 #20.0 # 20.0 #2 for garments
|
23 |
+
delta_clip_weight: 5 #10.0 # 2 for garments
|
24 |
+
vgg_weight: 0.0
|
25 |
+
face_weight: 0
|
26 |
+
regularize_jacobians_weight: 0.15 # 0.15 #0.5
|
27 |
+
consistency_loss_weight: 0 #-0.25 #-0.15 #-0.5 # -0.25 #-0.5 # -0.25 for garments
|
28 |
+
consistency_elev_filter: 30
|
29 |
+
consistency_azim_filter: 20
|
30 |
+
batch_size: 24
|
31 |
+
train_res: 512
|
32 |
+
resize_method: cubic
|
33 |
+
|
34 |
+
# Camera parameters
|
35 |
+
fov_min: 30.0
|
36 |
+
fov_max: 90.0
|
37 |
+
dist_min: 2.5
|
38 |
+
dist_max: 3.5
|
39 |
+
light_power: 5.0
|
40 |
+
elev_alpha: 1.0
|
41 |
+
elev_beta: 5.0
|
42 |
+
elev_max: 60.0
|
43 |
+
azim_min: 0.0
|
44 |
+
azim_max: 360.0
|
45 |
+
aug_loc: 1
|
46 |
+
aug_light: 1
|
47 |
+
aug_bkg: 0
|
48 |
+
adapt_dist: 1
|
49 |
+
|
50 |
+
log_interval: 5
|
51 |
+
log_interval_im: 150
|
52 |
+
log_elev: 0
|
53 |
+
log_fov: 60.0
|
54 |
+
log_dist: 3.0
|
55 |
+
log_res: 512
|
56 |
+
log_light_power: 3.0
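The file above is consumed as a plain dictionary by loop() (which wraps it in an EasyDict); the text prompts themselves are not part of the YAML and are supplied at run time, as app.py above does. A minimal sketch of wiring this up by hand (the repository's own entry points are app.py and, presumably, main.py; the prompts below are only example values):

    import yaml
    from loop import loop

    with open('example_config.yml') as f:
        cfg = yaml.safe_load(f)
    cfg['text_prompt'] = 'leather jacket with studs'
    cfg['base_text_prompt'] = 'simple t-shirt'
    loop(cfg)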
|
get_embeddings.py
ADDED
@@ -0,0 +1,58 @@
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
|
5 |
+
def get_fashion_text_embeddings(fclip, cfg, device):
|
6 |
+
print(f'Target text prompt is {cfg.text_prompt}')
|
7 |
+
print(f'Base text prompt is {cfg.base_text_prompt}')
|
8 |
+
with torch.no_grad():
|
9 |
+
text_embeds = fclip.encode_text_tensors([cfg.text_prompt]).detach()
|
10 |
+
base_text_embeds = fclip.encode_text_tensors([cfg.base_text_prompt]).detach()
|
11 |
+
target_text_embeds = text_embeds.clone() / text_embeds.norm(dim=1, keepdim=True)
|
12 |
+
delta_text_embeds = text_embeds - base_text_embeds
|
13 |
+
delta_text_embeds = delta_text_embeds / delta_text_embeds.norm(dim=1, keepdim=True)
|
14 |
+
return target_text_embeds.to(device), delta_text_embeds.to(device)
|
15 |
+
|
16 |
+
|
17 |
+
def get_fashion_img_embeddings(fclip, cfg, device, normalize=True):
|
18 |
+
print(f'Target image path is {cfg.image_prompt}')
|
19 |
+
print(f'Base image path is {cfg.base_image_prompt}')
|
20 |
+
with torch.no_grad():
|
21 |
+
target_image_embeds = fclip.encode_images([cfg.image_prompt], 1)
|
22 |
+
target_image_embeds = torch.tensor(target_image_embeds, device=device).detach()
|
23 |
+
|
24 |
+
base_image_embeds = fclip.encode_images([cfg.base_image_prompt], 1)
|
25 |
+
base_image_embeds = torch.tensor(base_image_embeds, device=device).detach()
|
26 |
+
delta_img_embeds = target_image_embeds - base_image_embeds
|
27 |
+
if normalize:
|
28 |
+
delta_img_embeds = delta_img_embeds / delta_img_embeds.norm(dim=1, keepdim=True)
|
29 |
+
target_image_embeds = target_image_embeds.clone() / target_image_embeds.norm(dim=1, keepdim=True)
|
30 |
+
return target_image_embeds.to(device), delta_img_embeds.to(device)
|
31 |
+
|
32 |
+
|
33 |
+
def get_text_embeddings(clip, model, cfg, device):
|
34 |
+
print(f'Target text prompt is {cfg.text_prompt}')
|
35 |
+
print(f'Base text prompt is {cfg.base_text_prompt}')
|
36 |
+
text_embeds = clip.tokenize(cfg.text_prompt).to(device)
|
37 |
+
base_text_embeds = clip.tokenize(cfg.base_text_prompt).to(device)
|
38 |
+
with torch.no_grad():
|
39 |
+
text_embeds = model.encode_text(text_embeds).detach()
|
40 |
+
target_text_embeds = text_embeds.clone() / text_embeds.norm(dim=1, keepdim=True)
|
41 |
+
delta_text_embeds = text_embeds - model.encode_text(base_text_embeds)
|
42 |
+
delta_text_embeds = delta_text_embeds / delta_text_embeds.norm(dim=1, keepdim=True)
|
43 |
+
return target_text_embeds, delta_text_embeds
|
44 |
+
|
45 |
+
|
46 |
+
def get_img_embeddings(model, preprocess, cfg, device):
|
47 |
+
print(f'Target image path is {cfg.image_prompt}')
|
48 |
+
print(f'Base image path is {cfg.base_image_prompt}')
|
49 |
+
|
50 |
+
image = preprocess(Image.open(cfg.image_prompt)).unsqueeze(0).to(device)
|
51 |
+
base_image = preprocess(Image.open(cfg.base_image_prompt)).unsqueeze(0).to(device)
|
52 |
+
with torch.no_grad():
|
53 |
+
target_image_embeds = model.encode_image(image).to(device).detach()
|
54 |
+
base_image_embeds = model.encode_image(base_image).to(device)
|
55 |
+
|
56 |
+
delta_img_embeds = target_image_embeds - base_image_embeds
|
57 |
+
delta_img_embeds = delta_img_embeds / delta_img_embeds.norm(dim=1, keepdim=True)
|
58 |
+
return target_image_embeds, delta_img_embeds
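The delta embeddings returned here (target minus base, normalized) are meant for a directional objective: the change between the current render's embedding and the base render's embedding is encouraged to point in the same direction. The repository's actual loss lives in loop.py / utilities.helpers (cosine_avg); the following is only a generic sketch of the idea with hypothetical tensor names:

    import torch

    def directional_loss(render_embeds, base_render_embeds, delta_target_embeds):
        # render_embeds, base_render_embeds: (B, D) CLIP embeddings of current and base renders
        delta_img = render_embeds - base_render_embeds
        delta_img = delta_img / delta_img.norm(dim=1, keepdim=True)
        return 1.0 - torch.cosine_similarity(delta_img, delta_target_embeds, dim=1).mean()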
|
loop.py
ADDED
@@ -0,0 +1,436 @@
1 |
+
import clip
|
2 |
+
import kornia
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import pathlib
|
6 |
+
import torchvision
|
7 |
+
import logging
|
8 |
+
import yaml
|
9 |
+
import nvdiffrast.torch as dr
|
10 |
+
from easydict import EasyDict
|
11 |
+
|
12 |
+
from NeuralJacobianFields import SourceMesh
|
13 |
+
|
14 |
+
from nvdiffmodeling.src import render
|
15 |
+
|
16 |
+
from torch.utils.tensorboard import SummaryWriter
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
from utilities.video import Video
|
20 |
+
from utilities.helpers import cosine_avg, create_scene, l1_avg
|
21 |
+
from utilities.camera import CameraBatch, get_camera_params
|
22 |
+
from utilities.clip_spatial import CLIPVisualEncoder
|
23 |
+
from utilities.resize_right import resize, cubic, linear, lanczos2, lanczos3
|
24 |
+
from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
|
25 |
+
from utils import *
|
26 |
+
from get_embeddings import *
|
27 |
+
|
28 |
+
from pytorch3d.structures import Meshes
|
29 |
+
from pytorch3d.loss import (
|
30 |
+
chamfer_distance,
|
31 |
+
mesh_edge_loss,
|
32 |
+
mesh_laplacian_smoothing,
|
33 |
+
mesh_normal_consistency,
|
34 |
+
)
|
35 |
+
from pytorch3d.ops import sample_points_from_meshes
|
36 |
+
|
37 |
+
|
38 |
+
def total_triangle_area(vertices):
|
39 |
+
# Calculate the sum of the areas of all triangles in the mesh
|
40 |
+
num_triangles = vertices.shape[0] // 3
|
41 |
+
triangle_vertices = vertices.view(num_triangles, 3, 3)
|
42 |
+
|
43 |
+
# Calculate the cross product for each triangle
|
44 |
+
cross_products = torch.cross(triangle_vertices[:, 1] - triangle_vertices[:, 0],
|
45 |
+
triangle_vertices[:, 2] - triangle_vertices[:, 0])
|
46 |
+
|
47 |
+
# Calculate the area of each triangle
|
48 |
+
areas = 0.5 * torch.norm(cross_products, dim=1)
|
49 |
+
|
50 |
+
# Sum the areas of all triangles
|
51 |
+
total_area = torch.sum(areas)
|
52 |
+
return total_area
|
53 |
+
|
54 |
+
def triangle_size_regularization(vertices):
|
55 |
+
# Regularization term: the squared total triangle area (see total_triangle_area above)
|
56 |
+
return total_triangle_area(vertices)**2
|
57 |
+
|
58 |
+
def loop(cfg):
|
59 |
+
clip_flag = True
|
60 |
+
output_path = pathlib.Path(cfg['output_path'])
|
61 |
+
os.makedirs(output_path, exist_ok=True)
|
62 |
+
with open(output_path / 'config.yml', 'w') as f:
|
63 |
+
yaml.dump(cfg, f, default_flow_style=False)
|
64 |
+
cfg = EasyDict(cfg)
|
65 |
+
|
66 |
+
print(f'Output directory {cfg.output_path} created')
|
67 |
+
os.makedirs(output_path / 'tmp', exist_ok=True)
|
68 |
+
|
69 |
+
device = torch.device(f'cuda:{cfg.gpu}')
|
70 |
+
torch.cuda.set_device(device)
|
71 |
+
|
72 |
+
text_input, image_input, fashion_image, fashion_text, use_target_mesh = False, False, False, False, True
|
73 |
+
CLIP_embeddings = False
|
74 |
+
|
75 |
+
if CLIP_embeddings:
|
76 |
+
print('Loading CLIP Models')
|
77 |
+
model, preprocess = clip.load(cfg.clip_model, device=device)
|
78 |
+
else:
|
79 |
+
fclip = FashionCLIP('fashion-clip')
|
80 |
+
|
81 |
+
fe = CLIPVisualEncoder(cfg.consistency_clip_model, cfg.consistency_vit_stride, device)

if fashion_text or fashion_image:
    target_direction_embeds, delta_direction_embeds = get_fashion_img_embeddings(fclip, cfg, device, True)
elif text_input:
    target_direction_embeds, delta_direction_embeds = get_text_embeddings(clip, model, cfg, device)
elif image_input:
    target_direction_embeds, delta_direction_embeds = get_img_embeddings(model, preprocess, cfg, device)

clip_mean = torch.tensor([0.48154660, 0.45782750, 0.40821073], device=device)
clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)

# output video
video = Video(cfg.output_path)
# GL Context
glctx = dr.RasterizeGLContext()


load_mesh = get_mesh(cfg.mesh, output_path, cfg.retriangulate, cfg.bsdf)

if use_target_mesh:
    target_mesh = get_mesh(cfg.target_mesh, output_path, cfg.retriangulate, cfg.bsdf, 'mesh_target.obj')
    # We construct a Meshes structure for the target mesh
    trg_mesh_p3d = Meshes(verts=[target_mesh.v_pos], faces=[target_mesh.t_pos_idx])


jacobian_source = SourceMesh.SourceMesh(0, str(output_path / 'tmp' / 'mesh.obj'), {}, 1, ttype=torch.float)
if len(list((output_path / 'tmp').glob('*.npz'))) > 0:
    logging.warn(f'Using existing Jacobian .npz files in {str(output_path)}/tmp/ ! Please check if this is intentional.')
jacobian_source.load()
jacobian_source.to(device)

with torch.no_grad():
    gt_jacobians = jacobian_source.jacobians_from_vertices(load_mesh.v_pos.unsqueeze(0))
gt_jacobians.requires_grad_(True)

optimizer = torch.optim.Adam([gt_jacobians], lr=cfg.lr)
cams_data = CameraBatch(
    cfg.train_res,
    [cfg.dist_min, cfg.dist_max],
    [cfg.azim_min, cfg.azim_max],
    [cfg.elev_alpha, cfg.elev_beta, cfg.elev_max],
    [cfg.fov_min, cfg.fov_max],
    cfg.aug_loc,
    cfg.aug_light,
    cfg.aug_bkg,
    cfg.batch_size,
    rand_solid=True
)
cams = torch.utils.data.DataLoader(cams_data, cfg.batch_size, num_workers=0, pin_memory=True)
best_losses = {'CLIP': np.inf, 'total': np.inf}

for out_type in ['final', 'best_clip', 'best_total', 'target_final']:
    os.makedirs(output_path / f'mesh_{out_type}', exist_ok=True)
os.makedirs(output_path / 'images', exist_ok=True)
logger = SummaryWriter(str(output_path / 'logs'))

rot_ang = 0.0
t_loop = tqdm(range(cfg.epochs), leave=False)

if cfg.resize_method == 'cubic':
    resize_method = cubic
elif cfg.resize_method == 'linear':
    resize_method = linear
elif cfg.resize_method == 'lanczos2':
    resize_method = lanczos2
elif cfg.resize_method == 'lanczos3':
    resize_method = lanczos3

for it in t_loop:

    # updated vertices from jacobians
    n_vert = jacobian_source.vertices_from_jacobians(gt_jacobians).squeeze()

    # TODO: More texture code required to make it work ...
    ready_texture = texture.Texture2D(
        kornia.filters.gaussian_blur2d(
            load_mesh.material['kd'].data.permute(0, 3, 1, 2),
            kernel_size=(7, 7),
            sigma=(3, 3),
        ).permute(0, 2, 3, 1).contiguous()
    )

    kd_notex = texture.Texture2D(torch.full_like(ready_texture.data, 0.5))

    ready_specular = texture.Texture2D(
        kornia.filters.gaussian_blur2d(
            load_mesh.material['ks'].data.permute(0, 3, 1, 2),
            kernel_size=(7, 7),
            sigma=(3, 3),
        ).permute(0, 2, 3, 1).contiguous()
    )

    ready_normal = texture.Texture2D(
        kornia.filters.gaussian_blur2d(
            load_mesh.material['normal'].data.permute(0, 3, 1, 2),
            kernel_size=(7, 7),
            sigma=(3, 3),
        ).permute(0, 2, 3, 1).contiguous()
    )

    # Final mesh
    m = mesh.Mesh(
        n_vert,
        load_mesh.t_pos_idx,
        material={
            'bsdf': cfg.bsdf,
            'kd': kd_notex,
            'ks': ready_specular,
            'normal': ready_normal,
        },
        base=load_mesh  # gets uvs etc from here
    )

    deformed_mesh_p3d = Meshes(verts=[m.v_pos], faces=[m.t_pos_idx])

    render_mesh = create_scene([m.eval()], sz=512)
    if it == 0:
        base_mesh = render_mesh.clone()
        base_mesh = mesh.auto_normals(base_mesh)
        base_mesh = mesh.compute_tangents(base_mesh)
    render_mesh = mesh.auto_normals(render_mesh)
    render_mesh = mesh.compute_tangents(render_mesh)

    if use_target_mesh:
        # Target mesh
        m_target = mesh.Mesh(
            target_mesh.v_pos,
            target_mesh.t_pos_idx,
            material={
                'bsdf': cfg.bsdf,
                'kd': kd_notex,
                'ks': ready_specular,
                'normal': ready_normal,
            },
            base=target_mesh
        )

        render_target_mesh = create_scene([m_target.eval()], sz=512)
        if it == 0:
            base_target_mesh = render_target_mesh.clone()
            base_target_mesh = mesh.auto_normals(base_target_mesh)
            base_target_mesh = mesh.compute_tangents(base_target_mesh)
        render_target_mesh = mesh.auto_normals(render_target_mesh)
        render_target_mesh = mesh.compute_tangents(render_target_mesh)


    # Logging mesh
    if it % cfg.log_interval == 0:
        with torch.no_grad():
            params = get_camera_params(
                cfg.log_elev,
                rot_ang,
                cfg.log_dist,
                cfg.log_res,
                cfg.log_fov,
            )
            rot_ang += 5
            log_mesh = mesh.unit_size(render_mesh.eval(params))
            log_image = render.render_mesh(
                glctx,
                log_mesh,
                params['mvp'],
                params['campos'],
                params['lightpos'],
                cfg.log_light_power,
                cfg.log_res,
                1,
                background=torch.ones(1, cfg.log_res, cfg.log_res, 3).to(device)
            )

            log_image = video.ready_image(log_image)
            logger.add_mesh('predicted_mesh', vertices=log_mesh.v_pos.unsqueeze(0), faces=log_mesh.t_pos_idx.unsqueeze(0), global_step=it)

    if cfg.adapt_dist and it > 0:
        with torch.no_grad():
            v_pos = m.v_pos.clone()
            vmin = v_pos.amin(dim=0)
            vmax = v_pos.amax(dim=0)
            v_pos -= (vmin + vmax) / 2
            mult = torch.cat([v_pos.amin(dim=0), v_pos.amax(dim=0)]).abs().amax().cpu()
            cams.dataset.dist_min = cfg.dist_min * mult
            cams.dataset.dist_max = cfg.dist_max * mult

    params_camera = next(iter(cams))
    for key in params_camera:
        params_camera[key] = params_camera[key].to(device)

    final_mesh = render_mesh.eval(params_camera)
    train_render = render.render_mesh(
        glctx,
        final_mesh,
        params_camera['mvp'],
        params_camera['campos'],
        params_camera['lightpos'],
        cfg.light_power,
        cfg.train_res,
        spp=1,
        num_layers=1,
        msaa=False,
        background=params_camera['bkgs']
    ).permute(0, 3, 1, 2)
    train_render = resize(train_render, out_shape=(224, 224), interp_method=resize_method)

    if use_target_mesh:
        final_target_mesh = render_target_mesh.eval(params_camera)
        train_target_render = render.render_mesh(
            glctx,
            final_target_mesh,
            params_camera['mvp'],
            params_camera['campos'],
            params_camera['lightpos'],
            cfg.light_power,
            cfg.train_res,
            spp=1,
            num_layers=1,
            msaa=False,
            background=params_camera['bkgs']
        ).permute(0, 3, 1, 2)
        train_target_render = resize(train_target_render, out_shape=(224, 224), interp_method=resize_method)

    train_rast_map = render.render_mesh(
        glctx,
        final_mesh,
        params_camera['mvp'],
        params_camera['campos'],
        params_camera['lightpos'],
        cfg.light_power,
        cfg.train_res,
        spp=1,
        num_layers=1,
        msaa=False,
        background=params_camera['bkgs'],
        return_rast_map=True
    )

    if it == 0:
        params_camera = next(iter(cams))
        for key in params_camera:
            params_camera[key] = params_camera[key].to(device)
    base_render = render.render_mesh(
        glctx,
        base_mesh.eval(params_camera),
        params_camera['mvp'],
        params_camera['campos'],
        params_camera['lightpos'],
        cfg.light_power,
        cfg.train_res,
        spp=1,
        num_layers=1,
        msaa=False,
        background=params_camera['bkgs'],
    ).permute(0, 3, 1, 2)
    base_render = resize(base_render, out_shape=(224, 224), interp_method=resize_method)

    if it % cfg.log_interval_im == 0:
        log_idx = torch.randperm(cfg.batch_size)[:5]
        s_log = train_render[log_idx, :, :, :]
        s_log = torchvision.utils.make_grid(s_log)
        ndarr = s_log.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(ndarr)
        im.save(str(output_path / 'images' / f'epoch_{it}.png'))

        if use_target_mesh:
            s_log_target = train_target_render[log_idx, :, :, :]
            s_log_target = torchvision.utils.make_grid(s_log_target)
            ndarr = s_log_target.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
            im = Image.fromarray(ndarr)
            im.save(str(output_path / 'images' / f'epoch_{it}_target.png'))

        obj.write_obj(
            str(output_path / 'mesh_final'),
            m.eval()
        )

    optimizer.zero_grad()


    normalized_clip_render = (train_render - clip_mean[None, :, None, None]) / clip_std[None, :, None, None]

    deformed_features = fclip.encode_image_tensors(train_render)
    target_features = fclip.encode_image_tensors(train_target_render)
    garment_loss = l1_avg(deformed_features, target_features)
    l1_loss = l1_avg(train_render, train_target_render)

    # We sample 10k points from the surface of each mesh
    sample_src = sample_points_from_meshes(deformed_mesh_p3d, 10000)
    sample_trg = sample_points_from_meshes(trg_mesh_p3d, 10000)

    # We compare the two sets of pointclouds by computing (a) the chamfer loss
    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
    loss_chamfer *= 25.
    #
    # and (b) the edge length of the predicted mesh
    loss_edge = mesh_edge_loss(deformed_mesh_p3d)

    # mesh normal consistency
    loss_normal = mesh_normal_consistency(deformed_mesh_p3d)

    # mesh laplacian smoothing
    loss_laplacian = mesh_laplacian_smoothing(deformed_mesh_p3d, method="uniform")

    loss_triangles = triangle_size_regularization(deformed_mesh_p3d.verts_list()[0]) / 100000.

    logger.add_scalar('l1_loss', l1_loss, global_step=it)
    logger.add_scalar('garment_loss', garment_loss, global_step=it)

    # Jacobian regularization
    r_loss = (((gt_jacobians) - torch.eye(3, 3, device=device)) ** 2).mean()
    logger.add_scalar('jacobian_regularization', r_loss, global_step=it)

    if cfg.consistency_loss_weight != 0:
        consistency_loss = compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device)
    else:
        consistency_loss = r_loss
    logger.add_scalar('consistency_loss', consistency_loss, global_step=it)

    logger.add_scalar('chamfer', loss_chamfer, global_step=it)
    logger.add_scalar('edge', loss_edge, global_step=it)
    logger.add_scalar('normal', loss_normal, global_step=it)
    logger.add_scalar('laplacian', loss_laplacian, global_step=it)
    logger.add_scalar('triangles', loss_triangles, global_step=it)


    if it > 1000 and clip_flag:
        cfg.clip_weight = 0
        cfg.consistency_loss_weight = 0
        cfg.regularize_jacobians_weight = 0.025
        clip_flag = False
    regularizers = loss_chamfer + loss_edge + loss_normal + loss_laplacian + loss_triangles
    total_loss = (cfg.clip_weight * garment_loss + cfg.delta_clip_weight * l1_loss +
                  cfg.regularize_jacobians_weight * r_loss + cfg.consistency_loss_weight * consistency_loss + regularizers)

    logger.add_scalar('total_loss', total_loss, global_step=it)

    total_loss.backward()
    optimizer.step()
    t_loop.set_description(
        f'L1 = {cfg.delta_clip_weight * l1_loss.item()}, '
        f'CLIP = {cfg.clip_weight * garment_loss.item()}, '
        f'Jacb = {cfg.regularize_jacobians_weight * r_loss.item()}, '
        f'MVC = {cfg.consistency_loss_weight * consistency_loss.item()}, '
        f'Chamf = {loss_chamfer.item()}, '
        f'Edge = {loss_edge.item()}, '
        f'Normal = {loss_normal.item()}, '
        f'Lapl = {loss_laplacian.item()}, '
        f'Triang = {loss_triangles.item()}, '
        f'Total = {total_loss.item()}')

video.close()
obj.write_obj(
    str(output_path / 'mesh_final'),
    m.eval()
)

return
main.py
ADDED
@@ -0,0 +1,93 @@
import os
import yaml
import torch
import random
import argparse
import numpy as np

from loop import loop

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help='Path to config file', type=str, default='./example_config.yml')
    parser.add_argument('--output_path', help='Output directory (will be created)', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--gpu', help='GPU index', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--seed', help='Random seed', type=int, default=argparse.SUPPRESS)

    # CLIP-related
    parser.add_argument('--text_prompt', help='Target text prompt', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--base_text_prompt', help='Base text prompt describing input mesh', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--clip_model', help='CLIP model for text comparison', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_clip_model', help='CLIP model for consistency', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_vit_stride', help='New stride for ViT patch interpolation', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_vit_layer', help='Which layer to take ViT patch features from (0-11)', type=int, default=argparse.SUPPRESS)

    # Mesh
    parser.add_argument('--mesh', help='Path to input mesh', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--retriangulate', help='Use isotropic remeshing', type=int, default=argparse.SUPPRESS, choices=[0, 1])

    # Render settings
    parser.add_argument('--bsdf', help='Render technique', type=str, default=argparse.SUPPRESS, choices=['diffuse', 'pbr'])

    # Hyper-parameters
    parser.add_argument('--lr', help='Learning rate', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--epochs', help='Number of optimization steps', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--clip_weight', help='Weight for CLIP loss', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--delta_clip_weight', help='Weight for delta-CLIP loss', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--regularize_jacobians_weight', help='Weight for jacobian regularization', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_loss_weight', help='Weight for viewpoint consistency penalty', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_elev_filter', help='Elev. angle threshold for filtering out pairs of viewpoints for consistency loss', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_azim_filter', help='Azim. angle threshold for filtering out pairs of viewpoints for consistency loss', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--batch_size', help='Number of images rendered at the same time', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--train_res', help='Resolution of render before downscaling to CLIP size', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--resize_method', help='Image downsampling/upsampling method', type=str, default=argparse.SUPPRESS, choices=['cubic', 'linear', 'lanczos2', 'lanczos3'])
    ## Camera Parameters ##
    parser.add_argument('--fov_min', help='Minimum camera field of view angle during renders', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--fov_max', help='Maximum camera field of view angle during renders', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--dist_min', help='Minimum distance of camera from mesh during renders', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--dist_max', help='Maximum distance of camera from mesh during renders', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--light_power', help='Light intensity', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--elev_alpha', help='Alpha parameter for Beta distribution for elevation sampling', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--elev_beta', help='Beta parameter for Beta distribution for elevation sampling', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--elev_max', help='Maximum elevation angle in degrees', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--azim_min', help='Minimum azimuth angle in degrees', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--azim_max', help='Maximum azimuth angle in degrees', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--aug_loc', help='Offset mesh from center of image?', type=int, default=argparse.SUPPRESS, choices=[0, 1])
    parser.add_argument('--aug_light', help='Augment the direction of light around the camera', type=int, default=argparse.SUPPRESS, choices=[0, 1])
    parser.add_argument('--aug_bkg', help='Augment the background', type=int, default=argparse.SUPPRESS, choices=[0, 1])
    parser.add_argument('--adapt_dist', help='Adjust camera distance to account for scale of shape', type=int, default=argparse.SUPPRESS, choices=[0, 1])

    # Logging
    parser.add_argument('--log_interval', help='Interval for logging, every X epochs', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--log_interval_im', help='Interval for logging rendered images, every X epochs', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--log_elev', help='Logging elevation angle', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--log_fov', help='Logging field of view', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--log_dist', help='Logging distance from object', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--log_res', help='Logging render resolution', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--log_light_power', help='Light intensity for logging', type=float, default=argparse.SUPPRESS)

    args = parser.parse_args()
    if args.config is not None:
        with open(args.config, 'r') as f:
            try:
                cfg = yaml.safe_load(f)
            except yaml.YAMLError as e:
                print(e)

    for key in vars(args):
        cfg[key] = vars(args)[key]

    print(yaml.dump(cfg, default_flow_style=False))
    random.seed(cfg['seed'])
    os.environ['PYTHONHASHSEED'] = str(cfg['seed'])
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])
    torch.cuda.manual_seed(cfg['seed'])
    torch.backends.cudnn.deterministic = True

    loop(cfg)
    print('Done')

if __name__ == '__main__':
    main()
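A minimal invocation sketch for the entry point above (the output directory name is illustrative and not taken from the upload):

    python main.py --config example_config.yml --output_path ./outputs/run1 --gpu 0 --seed 42

Because every optional flag defaults to argparse.SUPPRESS, only flags explicitly passed on the command line end up in vars(args) and override the corresponding keys loaded from the YAML config.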
meshes/UV.png
ADDED
Binary image file (stored with Git LFS); preview not rendered.
meshes/dress_shortsleeve.obj
ADDED
The diff for this file is too large to render.
meshes/longsleeve.mtl
ADDED
@@ -0,0 +1,13 @@
# Blender MTL File: 'UV_warpings.blend'
# Material Count: 1

newmtl Material1770410.001
Ns 250.000000
Ka 1.000000 1.000000 1.000000
Kd 0.800000 0.800000 0.800000
Ks 0.500000 0.500000 0.500000
Ke 0.000000 0.000000 0.000000
Ni 1.450000
d 1.000000
illum 2
map_Kd UV.png
meshes/longsleeve.obj
ADDED
The diff for this file is too large to render.
meshes/poncho.obj
ADDED
The diff for this file is too large to render.
meshes/tanktop.obj
ADDED
The diff for this file is too large to render.
meshes/tshirt.obj
ADDED
The diff for this file is too large to render.
meshes_target/jacket_sdf_new.obj
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9b224a72b3385ab25aeb86fc357d091db111fb97caf470235b881b4a8116645a
size 27668787
nvdiffmodeling/LICENSE.txt
ADDED
@@ -0,0 +1,97 @@
Copyright (c) 2021, NVIDIA Corporation. All rights reserved.


Nvidia Source Code License (1-Way Commercial)

=======================================================================

1. Definitions

"Licensor" means any person or entity that distributes its Work.

"Software" means the original work of authorship made available under
this License.

"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.

The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.

Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.

2. License Grants

2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. The Work or
derivative works thereof may be used or intended for use by Nvidia
or its affiliates commercially or non-commercially. As used herein,
"non-commercially" means for research or evaluation purposes only
and not for any direct or indirect monetary gain.

3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grant in Section 2.1) will terminate
immediately.

3.5 Trademarks. This License does not grant any rights to use any
Licensor's or its affiliates' names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.

3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grant in Section 2.1) will
terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.

=======================================================================
nvdiffmodeling/src/material.py
ADDED
@@ -0,0 +1,149 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import torch

from . import util
from . import texture
from . import mesh

######################################################################################
# .mtl material format loading / storing
######################################################################################

def load_mtl(fn, clear_ks=True):
    import re
    mtl_path = os.path.dirname(fn)

    # Read file
    with open(fn) as f:
        lines = f.readlines()

    # Parse materials
    materials = []
    for line in lines:
        split_line = re.split(' +|\t+|\n+', line.strip())
        prefix = split_line[0].lower()
        data = split_line[1:]
        if 'newmtl' in prefix:
            material = {'name' : data[0]}
            materials += [material]
        elif materials:
            if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix:
                material[prefix] = data[0]
            else:
                material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')

    # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps
    for mat in materials:
        if not 'bsdf' in mat:
            mat['bsdf'] = 'pbr'

        if 'map_kd' in mat:
            mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']))
        else:
            mat['kd'] = texture.Texture2D(mat['kd'])

        if 'map_ks' in mat:
            mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3)
        else:
            mat['ks'] = texture.Texture2D(mat['ks'])

        if 'bump' in mat:
            mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3)

        # Convert Kd from sRGB to linear RGB
        mat['kd'] = texture.srgb_to_rgb(mat['kd'])

        if clear_ks:
            # Override ORM occlusion (red) channel by zeros. We hijack this channel
            for mip in mat['ks'].getMips():
                mip[..., 0] = 0.0

    return materials

def save_mtl(fn, material):
    folder = os.path.dirname(fn)
    with open(fn, "w") as f:
        f.write('newmtl defaultMat\n')
        if material is not None:
            f.write('bsdf %s\n' % material['bsdf'])
            f.write('map_kd texture_kd.png\n')
            texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd']))
            f.write('map_ks texture_ks.png\n')
            texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks'])
            f.write('bump texture_n.png\n')
            texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(x+1)*0.5)
        else:
            f.write('Kd 1 1 1\n')
            f.write('Ks 0 0 0\n')
            f.write('Ka 0 0 0\n')
            f.write('Tf 1 1 1\n')
            f.write('Ni 1\n')
            f.write('Ns 0\n')

######################################################################################
# Merge multiple materials into a single uber-material
######################################################################################

def _upscale_replicate(x, full_res):
    x = x.permute(0, 3, 1, 2)
    x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate')
    return x.permute(0, 2, 3, 1).contiguous()

def merge_materials(materials, texcoords, tfaces, mfaces):
    assert len(materials) > 0
    for mat in materials:
        assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)"
        assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled"

    uber_material = {
        'name' : 'uber_material',
        'bsdf' : materials[0]['bsdf'],
    }

    textures = ['kd', 'ks', 'normal']

    # Find maximum texture resolution across all materials and textures
    max_res = None
    for mat in materials:
        for tex in textures:
            tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1])
            max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res

    # Compute size of compound texture and round up to nearest PoT
    full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(np.int)

    # Normalize texture resolution across all materials & combine into a single large texture
    for tex in textures:
        if tex in materials[0]:
            tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x
            tex_data = _upscale_replicate(tex_data, full_res)
            uber_material[tex] = texture.Texture2D(tex_data)

    # Compute scaling values for used / unused texture area
    s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]]

    # Recompute texture coordinates to coincide with new composite texture
    new_tverts = {}
    new_tverts_data = []
    for fi in range(len(tfaces)):
        matIdx = mfaces[fi]
        for vi in range(3):
            ti = tfaces[fi][vi]
            if not (ti in new_tverts):
                new_tverts[ti] = {}
            if not (matIdx in new_tverts[ti]): # create new vertex
                new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coordinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here
                new_tverts[ti][matIdx] = len(new_tverts_data) - 1
            tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex

    return uber_material, new_tverts_data, tfaces
nvdiffmodeling/src/mesh.py
ADDED
@@ -0,0 +1,510 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import torch

from . import util
from . import texture

######################################################################################
# Base mesh class
######################################################################################
class Mesh:
    def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=None, v_tex=None, t_tex_idx=None, v_tng=None, t_tng_idx=None,
                 v_weights=None, bone_mtx=None, material=None, base=None):
        self.v_pos = v_pos
        self.v_weights = v_weights
        self.v_nrm = v_nrm
        self.v_tex = v_tex
        self.v_tng = v_tng
        self.t_pos_idx = t_pos_idx
        self.t_nrm_idx = t_nrm_idx
        self.t_tex_idx = t_tex_idx
        self.t_tng_idx = t_tng_idx
        self.material = material
        self.bone_mtx = bone_mtx

        if base is not None:
            self.copy_none(base)

    def copy_none(self, other):
        if self.v_pos is None:
            self.v_pos = other.v_pos
        if self.v_weights is None:
            self.v_weights = other.v_weights
        if self.t_pos_idx is None:
            self.t_pos_idx = other.t_pos_idx
        if self.v_nrm is None:
            self.v_nrm = other.v_nrm
        if self.t_nrm_idx is None:
            self.t_nrm_idx = other.t_nrm_idx
        if self.v_tex is None:
            self.v_tex = other.v_tex
        if self.t_tex_idx is None:
            self.t_tex_idx = other.t_tex_idx
        if self.v_tng is None:
            self.v_tng = other.v_tng
        if self.t_tng_idx is None:
            self.t_tng_idx = other.t_tng_idx
        if self.material is None:
            self.material = other.material
        if self.bone_mtx is None:
            self.bone_mtx = other.bone_mtx

    def get_frames(self):
        return self.bone_mtx.shape[0] if self.bone_mtx is not None else 1

    def clone(self):
        out = Mesh(base=self)
        if out.v_pos is not None:
            out.v_pos = out.v_pos.clone()
        if out.v_weights is not None:
            out.v_weights = out.v_weights.clone()
        if out.t_pos_idx is not None:
            out.t_pos_idx = out.t_pos_idx.clone()
        if out.v_nrm is not None:
            out.v_nrm = out.v_nrm.clone()
        if out.t_nrm_idx is not None:
            out.t_nrm_idx = out.t_nrm_idx.clone()
        if out.v_tex is not None:
            out.v_tex = out.v_tex.clone()
        if out.t_tex_idx is not None:
            out.t_tex_idx = out.t_tex_idx.clone()
        if out.v_tng is not None:
            out.v_tng = out.v_tng.clone()
        if out.t_tng_idx is not None:
            out.t_tng_idx = out.t_tng_idx.clone()
        if out.bone_mtx is not None:
            out.bone_mtx = out.bone_mtx.clone()
        return out

    def eval(self, params={}):
        return self

######################################################################################
# Compute AABB
######################################################################################
def aabb(mesh):
    return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values

######################################################################################
# Align base mesh to reference mesh: move & rescale to match bounding boxes.
######################################################################################
def unit_size(mesh):
    with torch.no_grad():
        vmin, vmax = aabb(mesh)
        scale = 2 / torch.max(vmax - vmin).item()
        v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin
        v_pos = v_pos * scale                  # Rescale to unit size

        return Mesh(v_pos, base=mesh)

def resize_mesh(mesh):
    scale = 0.03234645293868976
    vmax = torch.tensor([ 32.9707, 159.2754,  16.8091], device='cuda:0')
    vmin = torch.tensor([-28.7435,  97.4448, -18.4702], device='cuda:0')
    with torch.no_grad():
        v_pos = (mesh.v_pos / scale) + (vmax + vmin) / 2
        return Mesh(v_pos, base=mesh)


######################################################################################
# Center & scale mesh for rendering
#
# TODO: It should be better to compute camera position from animated reference mesh
# instead of centering and scaling all meshes
######################################################################################
def center_by_reference(base_mesh, ref_aabb, scale):
    center = (ref_aabb[0] + ref_aabb[1]) * 0.5
    scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item()
    v_pos = (base_mesh.v_pos - center[None, ...]) * scale
    return Mesh(v_pos, base=base_mesh)

######################################################################################
# Rescale base-mesh from NDC [-1, 1] space to same dimensions as reference mesh
######################################################################################
def align_with_reference(base_mesh, ref_mesh): # TODO: Fix normals?
    class mesh_op_align:
        def __init__(self, base_mesh, ref_mesh):
            self.base_mesh = base_mesh
            with torch.no_grad():
                b_vmin, b_vmax = aabb(base_mesh.eval())
                r_vmin, r_vmax = aabb(ref_mesh.eval())
                b_size = (b_vmax - b_vmin)
                self.offset = (r_vmax + r_vmin) / 2
                self.scale = (r_vmax - r_vmin) / torch.where(b_size > 1e-6, b_size, torch.ones_like(b_size))

        def eval(self, params={}):
            base_mesh = self.base_mesh.eval(params)
            v_pos = base_mesh.v_pos * self.scale[None, ...] + self.offset[None, ...]
            return Mesh(v_pos, base=base_mesh)

    return mesh_op_align(base_mesh, ref_mesh)

######################################################################################
# Skinning
######################################################################################

# Helper function to skin homogeneous vectors
def _skin_hvec(bone_mtx, weights, attr):
    attr_out = torch.matmul(attr[None, ...], bone_mtx) * torch.transpose(weights, 0, 1)[..., None]
    return attr_out.sum(dim=0)[:, :3]

def skinning(mesh):
    class mesh_op_skinning:
        def __init__(self, input):
            self.input = input

            mesh = self.input.eval()
            t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy()
            if mesh.t_nrm_idx is not None:
                self.nrm_remap = self._compute_remap(t_pos_idx, mesh.v_nrm.shape[0], mesh.t_nrm_idx.detach().cpu().numpy())
            if mesh.t_tng_idx is not None:
                self.tng_remap = self._compute_remap(t_pos_idx, mesh.v_tng.shape[0], mesh.t_tng_idx.detach().cpu().numpy())

        # Compute an index list with corresponding vertex index for each normal/tangent. Vertices may have multiple normals/tangents, but not the other way around
        def _compute_remap(self, t_pos_idx, n_attrs, t_attr_idx):
            assert len(t_pos_idx) == len(t_attr_idx)

            attr_vtx_idx = [None] * n_attrs
            for ti in range(0, len(t_pos_idx)):
                for vi in range(0, 3):
                    assert attr_vtx_idx[t_attr_idx[ti][vi]] is None or attr_vtx_idx[t_attr_idx[ti][vi]] == t_pos_idx[ti][vi], "Trying to skin a mesh with shared normals (normal with 2 sets of skinning weights)"
                    attr_vtx_idx[t_attr_idx[ti][vi]] = t_pos_idx[ti][vi]

            return torch.tensor(attr_vtx_idx, dtype=torch.int64, device='cuda')

        def eval(self, params={}):
            imesh = self.input.eval(params)

            if imesh.v_weights is None or imesh.bone_mtx is None:
                return imesh

            # Compute frame (assume looping animation). Note, bone_mtx is stored [Frame, Bone, ...]
            t_idx = params['time'] if 'time' in params else 0
            t_idx = (t_idx % imesh.bone_mtx.shape[0]) # Loop animation
            bone_mtx = imesh.bone_mtx[t_idx, ...]
            bone_mtx_it = torch.transpose(torch.inverse(bone_mtx), -2, -1)

            weights = imesh.v_weights
            assert weights.shape[1] == bone_mtx.shape[0]

            # Normalize weights
            weights = torch.abs(weights) # TODO: This stabilizes training, but I don't know why. All weights are already clamped to >0
            weights = weights / torch.sum(weights, dim=1, keepdim=True)

            # Skin position
            v_pos_out = _skin_hvec(bone_mtx, weights, util.to_hvec(imesh.v_pos, 1))

            # Skin normal
            v_nrm_out = None
            if imesh.v_nrm is not None:
                v_nrm_out = _skin_hvec(bone_mtx_it, weights[self.nrm_remap, ...], util.to_hvec(imesh.v_nrm, 0))
                v_nrm_out = util.safe_normalize(v_nrm_out)

            # Skin tangent
            v_tng_out = None
            if imesh.v_tng is not None:
                v_tng_out = _skin_hvec(bone_mtx, weights[self.tng_remap, ...], util.to_hvec(imesh.v_tng, 0))
                v_tng_out = util.safe_normalize(v_tng_out)

            if torch.is_anomaly_enabled():
                assert torch.all(torch.isfinite(v_pos_out))
                assert v_nrm_out is None or torch.all(torch.isfinite(v_nrm_out))
                assert v_tng_out is None or torch.all(torch.isfinite(v_tng_out))

            return Mesh(v_pos=v_pos_out[:, :3], v_nrm=v_nrm_out, v_tng=v_tng_out, base=imesh)

    return mesh_op_skinning(mesh)

# Skinning helper functions
def guess_weights(base_mesh, ref_mesh, N=10):
    base_v_pos = base_mesh.v_pos.detach().cpu().numpy()
    ref_v_pos = ref_mesh.v_pos.detach().cpu().numpy()
    ref_v_weights = ref_mesh.v_weights.detach().cpu().numpy()
    base_v_weights = np.zeros((base_v_pos.shape[0], ref_v_weights.shape[1]), dtype=np.float32)

    for v_idx, vtx in enumerate(base_v_pos):
        # Compute distance from current vertex to vertices in ref_mesh
        diff = ref_v_pos - vtx[None, ...]
        dist = np.sum(diff * diff, axis=-1)
        idxs = np.argpartition(dist, N)

        # Get the N nearest vertices
        sum_w = 0.0
        sum_vtx_w = np.zeros_like(ref_v_weights[0,...])
        for i in idxs[:N]:
            sum_w += 1.0 / max(dist[i], 0.001)
            sum_vtx_w += ref_v_weights[i, ...] / max(dist[i], 0.001)
        base_v_weights[v_idx, ...] = sum_vtx_w / sum_w

    return base_v_weights

def random_weights(base_mesh, ref_mesh):
    init = np.random.uniform(size=(base_mesh.v_pos.shape[0], ref_mesh.v_weights.shape[1]), low=0.0, high=1.0)
    return init / np.sum(init, axis=1, keepdims=True)


######################################################################################
# Simple smooth vertex normal computation
######################################################################################
def auto_normals(mesh):
    class mesh_op_auto_normals:
        def __init__(self, input):
            self.input = input

        def eval(self, params={}):
            imesh = self.input.eval(params)

            i0 = imesh.t_pos_idx[:, 0]
            i1 = imesh.t_pos_idx[:, 1]
            i2 = imesh.t_pos_idx[:, 2]

            v0 = imesh.v_pos[i0, :]
            v1 = imesh.v_pos[i1, :]
            v2 = imesh.v_pos[i2, :]

            face_normals = torch.cross(v1 - v0, v2 - v0)

            # Splat face normals to vertices
            v_nrm = torch.zeros_like(imesh.v_pos)
            v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
            v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
            v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)

            # Normalize, replace zero (degenerated) normals with some default value
            v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda'))

            self.v_nrm = util.safe_normalize(v_nrm)

            if torch.is_anomaly_enabled():
                assert torch.all(torch.isfinite(self.v_nrm))

            return Mesh(v_nrm = self.v_nrm, t_nrm_idx=imesh.t_pos_idx, base = imesh)

    return mesh_op_auto_normals(mesh)

######################################################################################
# Compute tangent space from texture map coordinates
# Follows http://www.mikktspace.com/ conventions
######################################################################################
def compute_tangents(mesh):
    class mesh_op_compute_tangents:
        def __init__(self, input):
            self.input = input

        def eval(self, params={}):
            imesh = self.input.eval(params)

            vn_idx = [None] * 3
            pos = [None] * 3
            tex = [None] * 3
            for i in range(0,3):
                pos[i] = imesh.v_pos[imesh.t_pos_idx[:, i]]
                tex[i] = imesh.v_tex[imesh.t_tex_idx[:, i]]
                vn_idx[i] = imesh.t_nrm_idx[:, i]

            tangents = torch.zeros_like(imesh.v_nrm)
            tansum = torch.zeros_like(imesh.v_nrm)

            # Compute tangent space for each triangle
            uve1 = tex[1] - tex[0]
            uve2 = tex[2] - tex[0]
            pe1 = pos[1] - pos[0]
            pe2 = pos[2] - pos[0]

            nom = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2])
            denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1])

            # Avoid division by zero for degenerated texture coordinates
            tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6))

            # Update all 3 vertices
            for i in range(0,3):
                idx = vn_idx[i][:, None].repeat(1,3)
                tangents.scatter_add_(0, idx, tang)                    # tangents[n_i] = tangents[n_i] + tang
                tansum.scatter_add_(0, idx, torch.ones_like(tang))     # tansum[n_i] = tansum[n_i] + 1
            tangents = tangents / tansum

            # Normalize and make sure tangent is perpendicular to normal
            tangents = util.safe_normalize(tangents)
            tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm)

            self.v_tng = tangents

            if torch.is_anomaly_enabled():
                assert torch.all(torch.isfinite(tangents))

            return Mesh(v_tng=self.v_tng, t_tng_idx=imesh.t_nrm_idx, base=imesh)

    return mesh_op_compute_tangents(mesh)

######################################################################################
# Subdivide each triangle into 4 new ones. Edge midpoint subdivision
######################################################################################
def subdivide(mesh, steps=1):
    class mesh_op_subdivide:
        def __init__(self, input):
            self.input = input
            self.new_vtx_idx = [None] * 4
            self.new_tri_idx = [None] * 4

            imesh = self.input.eval()
            v_attr = v_attr_orig = [imesh.v_pos, imesh.v_nrm, imesh.v_tex, imesh.v_tng]
            v_idx = v_idx_orig = [imesh.t_pos_idx, imesh.t_nrm_idx, imesh.t_tex_idx, imesh.t_tng_idx]

            for i, attr in enumerate(v_attr):
                if attr is not None:
                    tri_idx = v_idx[i].cpu().numpy()

                    # Find unique edges
                    edge_fetch_a = []
                    edge_fetch_b = []
                    edge_verts = {}
                    for tri in tri_idx:
                        for e_idx in range(0, 3):
                            v0 = tri[e_idx]
                            v1 = tri[(e_idx + 1) % 3]
                            if (v1, v0) not in edge_verts.keys():
                                edge_verts[(v0, v1)] = [len(edge_fetch_a), v0, v1]
                                edge_fetch_a += [v0]
                                edge_fetch_b += [v1]

                    # Create vertex fetch lists for computing midpoint vertices
                    self.new_vtx_idx[i] = [torch.tensor(edge_fetch_a, dtype=torch.int64, device='cuda'), torch.tensor(edge_fetch_b, dtype=torch.int64, device='cuda')]

                    # Create subdivided triangles
                    new_tri_idx = []
                    for tri in tri_idx:
                        v0, v1, v2 = tri
                        h0 = (edge_verts[(v0, v1)][0] if (v0, v1) in edge_verts.keys() else edge_verts[(v1, v0)][0]) + attr.shape[0]
                        h1 = (edge_verts[(v1, v2)][0] if (v1, v2) in edge_verts.keys() else edge_verts[(v2, v1)][0]) + attr.shape[0]
                        h2 = (edge_verts[(v2, v0)][0] if (v2, v0) in edge_verts.keys() else edge_verts[(v0, v2)][0]) + attr.shape[0]
                        new_tri_idx += [[v0, h0, h2], [h0, v1, h1], [h1, v2, h2], [h0, h1, h2]]
                    self.new_tri_idx[i] = torch.tensor(new_tri_idx, dtype=torch.int64, device='cuda')

        def eval(self, params={}):
            imesh = self.input.eval(params)

            v_attr = v_attr_orig = [imesh.v_pos, imesh.v_nrm, imesh.v_tex, imesh.v_tng]
            v_idx = v_idx_orig = [imesh.t_pos_idx, imesh.t_nrm_idx, imesh.t_tex_idx, imesh.t_tng_idx]

            for i, attr in enumerate(v_attr):
                if attr is not None:
                    # Create new edge midpoint attributes
                    edge_attr = (attr[self.new_vtx_idx[i][0], :] + attr[self.new_vtx_idx[i][1], :]) * 0.5
                    v_attr[i] = torch.cat([attr, edge_attr], dim=0)

                    # Copy new triangle lists
                    v_idx[i] = self.new_tri_idx[i]

            return Mesh(v_attr[0], v_idx[0], v_attr[1], v_idx[1], v_attr[2], v_idx[2], v_attr[3], v_idx[3], base=imesh)

    x = mesh
    for i in range(steps):
        x = mesh_op_subdivide(x)

    bm = mesh.eval()
    sm = x.eval()
    v_attr_orig = [bm.v_pos, bm.v_nrm, bm.v_tex, bm.v_tng]
    v_attr = [sm.v_pos, sm.v_nrm, sm.v_tex, sm.v_tng]
    v_idx_orig = [bm.t_pos_idx, bm.t_nrm_idx, bm.t_tex_idx, bm.t_tng_idx]
    v_idx = [sm.t_pos_idx, sm.t_nrm_idx, sm.t_tex_idx, sm.t_tng_idx]
    print("Subdivided mesh:")
    print("    Attrs: [%6d, %6d, %6d, %6d] -> [%6d, %6d, %6d, %6d]" % tuple(list((a.shape[0] if a is not None else 0) for a in v_attr_orig) + list((a.shape[0] if a is not None else 0) for a in v_attr)))
    print("    Indices: [%6d, %6d, %6d, %6d] -> [%6d, %6d, %6d, %6d]" % tuple(list((a.shape[0] if a is not None else 0) for a in v_idx_orig) + list((a.shape[0] if a is not None else 0) for a in v_idx)))

    return x

######################################################################################
# Displacement mapping
######################################################################################
def displace(mesh, displacement_map, scale=1.0, keep_connectivity=True):
    class mesh_op_displace:
        def __init__(self, input, displacement_map, scale, keep_connectivity):
            self.input = input
            self.displacement_map = displacement_map
            self.scale = scale
            self.keep_connectivity = keep_connectivity

        def eval(self, params={}):
            imesh = self.input.eval(params)

            if self.keep_connectivity:
                vd = torch.zeros_like(imesh.v_pos)
                vd_n = torch.zeros_like(imesh.v_pos)
                for i in range(0, 3):
                    v = imesh.v_pos[imesh.t_pos_idx[:, i], :]
                    n = imesh.v_nrm[imesh.t_nrm_idx[:, i], :]
                    t = imesh.v_tex[imesh.t_tex_idx[:, i], :]
                    v_displ = v + n * self.scale * util.tex_2d(self.displacement_map, t)

                    splat_idx = imesh.t_pos_idx[:, i, None].repeat(1,3)
                    vd.scatter_add_(0, splat_idx, v_displ)
                    vd_n.scatter_add_(0, splat_idx, torch.ones_like(v_displ))

                return Mesh(vd / vd_n, base=imesh)
            else:
                vd = torch.zeros([imesh.v_tex.shape[0], 3], dtype=torch.float32, device='cuda')
                vd_n = torch.zeros([imesh.v_tex.shape[0], 3], dtype=torch.float32, device='cuda')
                for i in range(0, 3):
                    v = imesh.v_pos[imesh.t_pos_idx[:, i], :]
                    n = imesh.v_nrm[imesh.t_nrm_idx[:, i], :]
                    t = imesh.v_tex[imesh.t_tex_idx[:, i], :]
                    v_displ = v + n * self.scale * util.tex_2d(self.displacement_map, t)

                    splat_idx = imesh.t_tex_idx[:, i, None].repeat(1, 3)
                    vd.scatter_add_(0, splat_idx, v_displ)
                    vd_n.scatter_add_(0, splat_idx, torch.ones_like(v_displ))

                return Mesh(vd / vd_n, mesh.t_tex_idx, base=imesh)

    return mesh_op_displace(mesh, displacement_map, scale, keep_connectivity)


######################################################################################
# Utilities to merge meshes / materials. No mesh-ops or differentiable stuff here.
######################################################################################

def merge(mesh_a, mesh_b):
    def _merge_attr_idx(a, b, a_idx, b_idx):
        if a is None and b is None:
            return None, None
        elif a is not None and b is None:
            return a, a_idx
        elif a is None and b is not None:
            return b, b_idx
        else:
            return torch.cat((a, b), dim=0), torch.cat((a_idx, b_idx + a.shape[0]), dim=0)

    v_pos, t_pos_idx = _merge_attr_idx(mesh_a.v_pos, mesh_b.v_pos, mesh_a.t_pos_idx, mesh_b.t_pos_idx)
    v_nrm, t_nrm_idx = _merge_attr_idx(mesh_a.v_nrm, mesh_b.v_nrm, mesh_a.t_nrm_idx, mesh_b.t_nrm_idx)
    v_tng, t_tng_idx = _merge_attr_idx(mesh_a.v_tng, mesh_b.v_tng, mesh_a.t_tng_idx, mesh_b.t_tng_idx)
    v_tex, t_tex_idx = _merge_attr_idx(mesh_a.v_tex, mesh_b.v_tex, mesh_a.t_tex_idx, mesh_b.t_tex_idx)

    if mesh_a.v_weights is None and mesh_b.v_weights is None:
        v_weights, bone_mtx = None, None
    elif mesh_a.v_weights is not None and mesh_b.v_weights is None:
        v_weights, bone_mtx = mesh_a.v_weights, mesh_a.bone_mtx
    elif mesh_a.v_weights is None and mesh_b.v_weights is not None:
        v_weights, bone_mtx = mesh_b.v_weights, mesh_b.bone_mtx
    else:
        if torch.all(mesh_a.bone_mtx == mesh_b.bone_mtx): # TODO: Wanted to test if same pointer
            bone_mtx = mesh_a.bone_mtx
            v_weights = torch.cat((mesh_a.v_weights, mesh_b.v_weights), dim=0)
        else:
            bone_mtx = torch.cat((mesh_a.bone_mtx, mesh_b.bone_mtx), dim=1) # Frame, Bone, ...

            # Weights need to be increased to account for all bones
            v_wa = torch.nn.functional.pad(mesh_a.v_weights, [0, mesh_b.v_weights.shape[1]]) # Pad weights_a with shape of weights_b
            v_wb = torch.nn.functional.pad(mesh_b.v_weights, [mesh_a.v_weights.shape[1], 0]) # Pad weights_b with shape of weights_a
            v_weights = torch.cat((v_wa, v_wb), dim=0)

    return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx, v_nrm=v_nrm, t_nrm_idx=t_nrm_idx, v_tng=v_tng, t_tng_idx=t_tng_idx, v_tex=v_tex, t_tex_idx=t_tex_idx, v_weights=v_weights, bone_mtx=bone_mtx, base=mesh_a)
nvdiffmodeling/src/obj.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from . import util
|
14 |
+
from . import texture
|
15 |
+
from . import mesh
|
16 |
+
from . import material
|
17 |
+
|
18 |
+
######################################################################################
|
19 |
+
# Utility functions
|
20 |
+
######################################################################################
|
21 |
+
|
22 |
+
def _write_weights(folder, mesh):
|
23 |
+
if mesh.v_weights is not None:
|
24 |
+
file = os.path.join(folder, 'mesh.weights')
|
25 |
+
np.save(file, mesh.v_weights.detach().cpu().numpy())
|
26 |
+
|
27 |
+
def _write_bones(folder, mesh):
|
28 |
+
if mesh.bone_mtx is not None:
|
29 |
+
file = os.path.join(folder, 'mesh.bones')
|
30 |
+
np.save(file, mesh.bone_mtx.detach().cpu().numpy())
|
31 |
+
|
32 |
+
def _find_mat(materials, name):
|
33 |
+
for mat in materials:
|
34 |
+
if mat['name'] == name:
|
35 |
+
return mat
|
36 |
+
return materials[0] # Materials 0 is the default
|
37 |
+
|
38 |
+
######################################################################################
|
39 |
+
# Create mesh object from objfile
|
40 |
+
######################################################################################
|
41 |
+
|
42 |
+
def load_obj(filename, clear_ks=True, mtl_override=None):
|
43 |
+
obj_path = os.path.dirname(filename)
|
44 |
+
|
45 |
+
# Read entire file
|
46 |
+
with open(filename) as f:
|
47 |
+
lines = f.readlines()
|
48 |
+
|
49 |
+
# Load materials
|
50 |
+
all_materials = [
|
51 |
+
{
|
52 |
+
'name' : '_default_mat',
|
53 |
+
'bsdf' : 'falcor',
|
54 |
+
'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')),
|
55 |
+
'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
|
56 |
+
}
|
57 |
+
]
|
58 |
+
if mtl_override is None:
|
59 |
+
for line in lines:
|
60 |
+
if len(line.split()) == 0:
|
61 |
+
continue
|
62 |
+
if line.split()[0] == 'mtllib':
|
63 |
+
all_materials += material.load_mtl(obj_path + os.path.join(line.split()[1]), clear_ks) # Read in entire material library #obj_path
|
64 |
+
else:
|
65 |
+
all_materials += material.load_mtl(mtl_override)
|
66 |
+
|
67 |
+
# load vertices
|
68 |
+
vertices, texcoords, normals = [], [], []
|
69 |
+
for line in lines:
|
70 |
+
if len(line.split()) == 0:
|
71 |
+
continue
|
72 |
+
|
73 |
+
prefix = line.split()[0].lower()
|
74 |
+
if prefix == 'v':
|
75 |
+
vertices.append([float(v) for v in line.split()[1:]][:3])
|
76 |
+
elif prefix == 'vt':
|
77 |
+
val = [float(v) for v in line.split()[1:]]
|
78 |
+
texcoords.append([val[0], 1.0 - val[1]])
|
79 |
+
elif prefix == 'vn':
|
80 |
+
normals.append([float(v) for v in line.split()[1:]])
|
81 |
+
|
82 |
+
# load faces
|
83 |
+
activeMatIdx = None
|
84 |
+
used_materials = []
|
85 |
+
faces, tfaces, nfaces, mfaces = [], [], [], []
|
86 |
+
for line in lines:
|
87 |
+
if len(line.split()) == 0:
|
88 |
+
continue
|
89 |
+
|
90 |
+
prefix = line.split()[0].lower()
|
91 |
+
if prefix == 'usemtl': # Track used materials
|
92 |
+
mat = _find_mat(all_materials, line.split()[1])
|
93 |
+
if not mat in used_materials:
|
94 |
+
used_materials.append(mat)
|
95 |
+
activeMatIdx = used_materials.index(mat)
|
96 |
+
elif prefix == 'f': # Parse face
|
97 |
+
vs = line.split()[1:]
|
98 |
+
nv = len(vs)
|
99 |
+
vv = vs[0].split('/')
|
100 |
+
v0 = int(vv[0]) - 1
|
101 |
+
if len(vv) > 1:
|
102 |
+
t0 = int(vv[1]) - 1 if vv[1] != "" else -1
|
103 |
+
n0 = int(vv[2]) - 1 if vv[2] != "" else -1
|
104 |
+
else:
|
105 |
+
t0 = -1
|
106 |
+
n0 = -1
|
107 |
+
for i in range(nv - 2): # Triangulate polygons
|
108 |
+
vv = vs[i + 1].split('/')
|
109 |
+
v1 = int(vv[0]) - 1
|
110 |
+
if len(vv) > 1:
|
111 |
+
t1 = int(vv[1]) - 1 if vv[1] != "" else -1
|
112 |
+
n1 = int(vv[2]) - 1 if vv[2] != "" else -1
|
113 |
+
else:
|
114 |
+
t1 = -1
|
115 |
+
n1 = -1
|
116 |
+
vv = vs[i + 2].split('/')
|
117 |
+
v2 = int(vv[0]) - 1
|
118 |
+
if len(vv) > 1:
|
119 |
+
t2 = int(vv[1]) - 1 if vv[1] != "" else -1
|
120 |
+
n2 = int(vv[2]) - 1 if vv[2] != "" else -1
|
121 |
+
else:
|
122 |
+
t2 = -1
|
123 |
+
n2 = -1
|
124 |
+
mfaces.append(activeMatIdx)
|
125 |
+
faces.append([v0, v1, v2])
|
126 |
+
tfaces.append([t0, t1, t2])
|
127 |
+
nfaces.append([n0, n1, n2])
|
128 |
+
assert len(tfaces) == len(faces) and len(nfaces) == len (faces)
|
129 |
+
|
130 |
+
# Create an "uber" material by combining all textures into a larger texture
|
131 |
+
if len(used_materials) > 1:
|
132 |
+
uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces)
|
133 |
+
elif len(used_materials) == 1:
|
134 |
+
uber_material = used_materials[0]
|
135 |
+
else:
|
136 |
+
uber_material = None
|
137 |
+
|
138 |
+
vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda')
|
139 |
+
texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None
|
140 |
+
normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None
|
141 |
+
|
142 |
+
faces = torch.tensor(faces, dtype=torch.int64, device='cuda')
|
143 |
+
tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None
|
144 |
+
nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None
|
145 |
+
|
146 |
+
# Read weights and bones if available
|
147 |
+
try:
|
148 |
+
v_weights = torch.tensor(np.load(os.path.splitext(filename)[0] + ".weights.npy"), dtype=torch.float32, device='cuda')
|
149 |
+
bone_mtx = torch.tensor(np.load(os.path.splitext(filename)[0] + ".bones.npy"), dtype=torch.float32, device='cuda')
|
150 |
+
except:
|
151 |
+
v_weights, bone_mtx = None, None
|
152 |
+
|
153 |
+
return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, v_weights=v_weights, bone_mtx=bone_mtx, material=uber_material)
|
154 |
+
|
155 |
+
######################################################################################
|
156 |
+
# Save mesh object to objfile
|
157 |
+
######################################################################################
|
158 |
+
|
159 |
+
def write_obj(folder, mesh, verbose=True):
|
160 |
+
obj_file = os.path.join(folder, 'mesh.obj')
|
161 |
+
if verbose:
|
162 |
+
print("Writing mesh: ", obj_file)
|
163 |
+
with open(obj_file, "w") as f:
|
164 |
+
f.write("mtllib mesh.mtl\n")
|
165 |
+
f.write("g default\n")
|
166 |
+
|
167 |
+
v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None
|
168 |
+
v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None
|
169 |
+
v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None
|
170 |
+
|
171 |
+
t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None
|
172 |
+
t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None
|
173 |
+
t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None
|
174 |
+
if verbose:
|
175 |
+
print(" writing %d vertices" % len(v_pos))
|
176 |
+
for v in v_pos:
|
177 |
+
f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
|
178 |
+
|
179 |
+
if v_tex is not None:
|
180 |
+
if verbose:
|
181 |
+
print(" writing %d texcoords" % len(v_tex))
|
182 |
+
assert(len(t_pos_idx) == len(t_tex_idx))
|
183 |
+
for v in v_tex:
|
184 |
+
f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))
|
185 |
+
|
186 |
+
if v_nrm is not None:
|
187 |
+
if verbose:
|
188 |
+
print(" writing %d normals" % len(v_nrm))
|
189 |
+
assert(len(t_pos_idx) == len(t_nrm_idx))
|
190 |
+
for v in v_nrm:
|
191 |
+
f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
|
192 |
+
|
193 |
+
# faces
|
194 |
+
f.write("s 1 \n")
|
195 |
+
f.write("g pMesh1\n")
|
196 |
+
f.write("usemtl defaultMat\n")
|
197 |
+
|
198 |
+
# Write faces
|
199 |
+
if verbose:
|
200 |
+
print(" writing %d faces" % len(t_pos_idx))
|
201 |
+
for i in range(len(t_pos_idx)):
|
202 |
+
f.write("f ")
|
203 |
+
for j in range(3):
|
204 |
+
f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
|
205 |
+
f.write("\n")
|
206 |
+
|
207 |
+
mtl_file = os.path.join(folder, 'mesh.mtl')
|
208 |
+
if verbose:
|
209 |
+
print("Writing material: ", mtl_file)
|
210 |
+
material.save_mtl(mtl_file, mesh.material)
|
211 |
+
|
212 |
+
_write_weights(folder, mesh)
|
213 |
+
_write_bones(folder, mesh)
|
214 |
+
if verbose:
|
215 |
+
print("Done exporting mesh")
|
nvdiffmodeling/src/regularizer.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from . import util
|
14 |
+
from . import texture
|
15 |
+
|
16 |
+
######################################################################################
|
17 |
+
# Computes the avergage edge length of a mesh.
|
18 |
+
# Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients
|
19 |
+
######################################################################################
|
20 |
+
def avg_edge_length(opt_mesh):
|
21 |
+
with torch.no_grad():
|
22 |
+
opt_mesh = opt_mesh.eval()
|
23 |
+
nVerts = opt_mesh.v_pos.shape[0]
|
24 |
+
t_pos_idx = opt_mesh.t_pos_idx.detach().cpu().numpy()
|
25 |
+
|
26 |
+
# Find unique edges
|
27 |
+
ix_i = []
|
28 |
+
ix_j = []
|
29 |
+
edge_verts = {}
|
30 |
+
for tri in t_pos_idx:
|
31 |
+
for (i0, i1) in [(tri[0], tri[1]), (tri[1], tri[2]), (tri[2], tri[0])]:
|
32 |
+
if (i1, i0) not in edge_verts.keys():
|
33 |
+
edge_verts[(i0, i1)] = True
|
34 |
+
ix_i += [i0]
|
35 |
+
ix_j += [i1]
|
36 |
+
|
37 |
+
# Setup torch tensors
|
38 |
+
ix_i = torch.tensor(ix_i, dtype=torch.int64, device='cuda')
|
39 |
+
ix_j = torch.tensor(ix_j, dtype=torch.int64, device='cuda')
|
40 |
+
|
41 |
+
# Gather edge vertex pairs
|
42 |
+
x_i = opt_mesh.v_pos[ix_i, :]
|
43 |
+
x_j = opt_mesh.v_pos[ix_j, :]
|
44 |
+
|
45 |
+
# Compute edge length
|
46 |
+
term = torch.sqrt((x_j - x_i)**2)
|
47 |
+
|
48 |
+
# Compute avg edge length
|
49 |
+
return (torch.sum(term) / len(x_i)).item()
|
50 |
+
|
51 |
+
######################################################################################
|
52 |
+
# Edge length regularizer
|
53 |
+
######################################################################################
|
54 |
+
def edge_length_regularizer(mesh):
|
55 |
+
class mesh_op_edge_length_regularizer:
|
56 |
+
def __init__(self, mesh):
|
57 |
+
self.mesh = mesh
|
58 |
+
|
59 |
+
mesh = mesh.eval()
|
60 |
+
nVerts = mesh.v_pos.shape[0]
|
61 |
+
t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy()
|
62 |
+
|
63 |
+
# Find unique edges
|
64 |
+
ix_i = []
|
65 |
+
ix_j = []
|
66 |
+
edge_verts = {}
|
67 |
+
for tri in t_pos_idx:
|
68 |
+
for (i0, i1) in [(tri[0], tri[1]), (tri[1], tri[2]), (tri[2], tri[0])]:
|
69 |
+
if (i1, i0) not in edge_verts.keys():
|
70 |
+
edge_verts[(i0, i1)] = True
|
71 |
+
ix_i += [i0]
|
72 |
+
ix_j += [i1]
|
73 |
+
|
74 |
+
# Setup torch tensors
|
75 |
+
self.ix_i = torch.tensor(ix_i, dtype=torch.int64, device='cuda')
|
76 |
+
self.ix_j = torch.tensor(ix_j, dtype=torch.int64, device='cuda')
|
77 |
+
|
78 |
+
def eval(self, params={}):
|
79 |
+
mesh = self.mesh.eval(params)
|
80 |
+
|
81 |
+
# Gather edge vertex pairs
|
82 |
+
x_i = mesh.v_pos[self.ix_i, :]
|
83 |
+
x_j = mesh.v_pos[self.ix_j, :]
|
84 |
+
|
85 |
+
# Compute edge length
|
86 |
+
term = torch.sqrt((x_j - x_i)**2 + 1e-20)
|
87 |
+
|
88 |
+
# Compute avg edge length
|
89 |
+
return torch.var(term)
|
90 |
+
|
91 |
+
return mesh_op_edge_length_regularizer(mesh)
|
92 |
+
|
93 |
+
######################################################################################
|
94 |
+
# Laplacian regularization using umbrella operator (Fujiwara / Desbrun).
|
95 |
+
# https://mgarland.org/class/geom04/material/smoothing.pdf
|
96 |
+
######################################################################################
|
97 |
+
def laplace_regularizer_const(opt_mesh, base_mesh=None):
|
98 |
+
class mesh_op_laplace_regularizer_const:
|
99 |
+
def __init__(self, opt_mesh, base_mesh):
|
100 |
+
self.inputs = [opt_mesh, base_mesh]
|
101 |
+
|
102 |
+
opt_mesh = opt_mesh.eval()
|
103 |
+
self.nVerts = opt_mesh.v_pos.shape[0]
|
104 |
+
t_pos_idx = opt_mesh.t_pos_idx.detach().cpu().numpy()
|
105 |
+
|
106 |
+
# Build vertex neighbor rings
|
107 |
+
vtx_n = [[] for _ in range(self.nVerts)]
|
108 |
+
for tri in t_pos_idx:
|
109 |
+
for (i0, i1) in [(tri[0], tri[1]), (tri[1], tri[2]), (tri[2], tri[0])]:
|
110 |
+
vtx_n[i0].append(i1)
|
111 |
+
|
112 |
+
# Collect index/weight pairs to compute each Laplacian vector for each vertex.
|
113 |
+
# Similar notation to https://mgarland.org/class/geom04/material/smoothing.pdf
|
114 |
+
ix_j, ix_i, w_ij = [], [], []
|
115 |
+
for i in range(self.nVerts):
|
116 |
+
m = len(vtx_n[i])
|
117 |
+
ix_i += [i] * m
|
118 |
+
ix_j += vtx_n[i]
|
119 |
+
w_ij += [1.0 / m] * m
|
120 |
+
|
121 |
+
# Setup torch tensors
|
122 |
+
self.ix_i = torch.tensor(ix_i, dtype=torch.int64, device='cuda')
|
123 |
+
self.ix_j = torch.tensor(ix_j, dtype=torch.int64, device='cuda')
|
124 |
+
self.w_ij = torch.tensor(w_ij, dtype=torch.float32, device='cuda')[:, None]
|
125 |
+
|
126 |
+
def eval(self, params={}):
|
127 |
+
opt_mesh = self.inputs[0].eval(params)
|
128 |
+
base_mesh = self.inputs[1].eval(params) if self.inputs[1] is not None else None
|
129 |
+
|
130 |
+
# differences or absolute version (see paper)
|
131 |
+
if base_mesh is not None:
|
132 |
+
v_pos = opt_mesh.v_pos - base_mesh.v_pos
|
133 |
+
else:
|
134 |
+
v_pos = opt_mesh.v_pos
|
135 |
+
|
136 |
+
# Gather edge vertex pairs
|
137 |
+
x_i = v_pos[self.ix_i, :]
|
138 |
+
x_j = v_pos[self.ix_j, :]
|
139 |
+
|
140 |
+
# Compute Laplacian differences: (x_j - x_i) * w_ij
|
141 |
+
term = (x_j - x_i) * self.w_ij
|
142 |
+
|
143 |
+
# Sum everyhing
|
144 |
+
term = util.segment_sum(term, self.ix_i)
|
145 |
+
|
146 |
+
return torch.mean(term**2)
|
147 |
+
|
148 |
+
return mesh_op_laplace_regularizer_const(opt_mesh, base_mesh)
|
149 |
+
|
150 |
+
######################################################################################
|
151 |
+
# Curvature based regularizer
|
152 |
+
######################################################################################
|
153 |
+
def face_normal_regularizer(opt_mesh):
|
154 |
+
class mesh_op_face_normal_regularizer:
|
155 |
+
def __init__(self, opt_mesh):
|
156 |
+
self.input = opt_mesh
|
157 |
+
|
158 |
+
imesh = opt_mesh.eval()
|
159 |
+
self.nVerts = imesh.v_pos.shape[0]
|
160 |
+
t_pos_idx = imesh.t_pos_idx.detach().cpu().numpy()
|
161 |
+
|
162 |
+
# Generate edge lists
|
163 |
+
edge_tris = {}
|
164 |
+
for tri_idx, tri in enumerate(t_pos_idx):
|
165 |
+
for (i0, i1) in [(tri[0], tri[1]), (tri[1], tri[2]), (tri[2], tri[0])]:
|
166 |
+
if (i1, i0) in edge_tris.keys():
|
167 |
+
edge_tris[(i1, i0)] += [tri_idx]
|
168 |
+
else:
|
169 |
+
edge_tris[(i0, i1)] = [tri_idx]
|
170 |
+
|
171 |
+
# Get all good edges with 2 incident triangles
|
172 |
+
shared_edge_idx = []
|
173 |
+
for edge in edge_tris.values():
|
174 |
+
if len(edge) == 2:
|
175 |
+
shared_edge_idx += [edge]
|
176 |
+
self.edge_tri_idx = torch.tensor(shared_edge_idx, dtype=torch.int64, device='cuda')
|
177 |
+
|
178 |
+
def eval(self, params={}):
|
179 |
+
imesh = self.input.eval(params)
|
180 |
+
|
181 |
+
# Compute face normals
|
182 |
+
v0 = imesh.v_pos[imesh.t_pos_idx[:, 0], :]
|
183 |
+
v1 = imesh.v_pos[imesh.t_pos_idx[:, 1], :]
|
184 |
+
v2 = imesh.v_pos[imesh.t_pos_idx[:, 2], :]
|
185 |
+
face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))
|
186 |
+
|
187 |
+
# Fetch normals for both faces sharind an edge
|
188 |
+
n0 = face_normals[self.edge_tri_idx[:, 0], :]
|
189 |
+
n1 = face_normals[self.edge_tri_idx[:, 1], :]
|
190 |
+
|
191 |
+
# Compute error metric based on normal difference
|
192 |
+
term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0)
|
193 |
+
term = (1.0 - term) * 0.5
|
194 |
+
|
195 |
+
return torch.mean(torch.abs(term))
|
196 |
+
|
197 |
+
return mesh_op_face_normal_regularizer(opt_mesh)
|
nvdiffmodeling/src/render.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import nvdiffrast.torch as dr
|
13 |
+
|
14 |
+
from . import util
|
15 |
+
from . import mesh
|
16 |
+
from . import renderutils as ru
|
17 |
+
|
18 |
+
# ==============================================================================================
|
19 |
+
# Helper functions
|
20 |
+
# ==============================================================================================
|
21 |
+
def interpolate(attr, rast, attr_idx, rast_db=None):
|
22 |
+
return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
|
23 |
+
|
24 |
+
# ==============================================================================================
|
25 |
+
# pixel shader
|
26 |
+
# ==============================================================================================
|
27 |
+
def shade(
|
28 |
+
gb_pos,
|
29 |
+
gb_geometric_normal,
|
30 |
+
gb_normal,
|
31 |
+
gb_tangent,
|
32 |
+
gb_texc,
|
33 |
+
gb_texc_deriv,
|
34 |
+
view_pos,
|
35 |
+
light_pos,
|
36 |
+
light_power,
|
37 |
+
material,
|
38 |
+
min_roughness
|
39 |
+
):
|
40 |
+
|
41 |
+
################################################################################
|
42 |
+
# Texture lookups
|
43 |
+
################################################################################
|
44 |
+
|
45 |
+
kd = material['kd'].sample(gb_texc, gb_texc_deriv)
|
46 |
+
ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha
|
47 |
+
perturbed_nrm = None
|
48 |
+
if 'normal' in material:
|
49 |
+
perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv)
|
50 |
+
|
51 |
+
gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)
|
52 |
+
|
53 |
+
# Separate kd into alpha and color, default alpha = 1
|
54 |
+
alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])
|
55 |
+
kd = kd[..., 0:3]
|
56 |
+
|
57 |
+
################################################################################
|
58 |
+
# Evaluate BSDF
|
59 |
+
################################################################################
|
60 |
+
|
61 |
+
assert 'bsdf' in material, "Material must specify a BSDF type"
|
62 |
+
if material['bsdf'] == 'pbr':
|
63 |
+
shaded_col = ru.pbr_bsdf(kd, ks, gb_pos, gb_normal, view_pos, light_pos, min_roughness) * light_power
|
64 |
+
elif material['bsdf'] == 'diffuse':
|
65 |
+
shaded_col = kd * ru.lambert(gb_normal, util.safe_normalize(light_pos - gb_pos)) * light_power
|
66 |
+
elif material['bsdf'] == 'normal':
|
67 |
+
shaded_col = (gb_normal + 1.0)*0.5
|
68 |
+
elif material['bsdf'] == 'tangent':
|
69 |
+
shaded_col = (gb_tangent + 1.0)*0.5
|
70 |
+
else:
|
71 |
+
assert False, "Invalid BSDF '%s'" % material['bsdf']
|
72 |
+
|
73 |
+
out = torch.cat((shaded_col, alpha), dim=-1)
|
74 |
+
|
75 |
+
return out
|
76 |
+
|
77 |
+
# ==============================================================================================
|
78 |
+
# Render a depth slice of the mesh (scene), some limitations:
|
79 |
+
# - Single mesh
|
80 |
+
# - Single light
|
81 |
+
# - Single material
|
82 |
+
# ==============================================================================================
|
83 |
+
def render_layer(
|
84 |
+
rast,
|
85 |
+
rast_deriv,
|
86 |
+
mesh,
|
87 |
+
view_pos,
|
88 |
+
light_pos,
|
89 |
+
light_power,
|
90 |
+
resolution,
|
91 |
+
min_roughness,
|
92 |
+
spp,
|
93 |
+
msaa
|
94 |
+
):
|
95 |
+
|
96 |
+
full_res = resolution*spp
|
97 |
+
|
98 |
+
################################################################################
|
99 |
+
# Rasterize
|
100 |
+
################################################################################
|
101 |
+
|
102 |
+
# Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
|
103 |
+
if spp > 1 and msaa:
|
104 |
+
rast_out_s = util.scale_img_nhwc(rast, [resolution, resolution], mag='nearest', min='nearest')
|
105 |
+
rast_out_deriv_s = util.scale_img_nhwc(rast_deriv, [resolution, resolution], mag='nearest', min='nearest') * spp
|
106 |
+
else:
|
107 |
+
rast_out_s = rast
|
108 |
+
rast_out_deriv_s = rast_deriv
|
109 |
+
|
110 |
+
################################################################################
|
111 |
+
# Interpolate attributes
|
112 |
+
################################################################################
|
113 |
+
|
114 |
+
# Interpolate world space position
|
115 |
+
gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())
|
116 |
+
|
117 |
+
# Compute geometric normals. We need those because of bent normals trick (for bump mapping)
|
118 |
+
v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :]
|
119 |
+
v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :]
|
120 |
+
v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :]
|
121 |
+
face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))
|
122 |
+
face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
|
123 |
+
gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())
|
124 |
+
|
125 |
+
# Compute tangent space
|
126 |
+
assert mesh.v_nrm is not None and mesh.v_tng is not None
|
127 |
+
gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
|
128 |
+
gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents
|
129 |
+
|
130 |
+
# Texure coordinate
|
131 |
+
assert mesh.v_tex is not None
|
132 |
+
gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s)
|
133 |
+
|
134 |
+
################################################################################
|
135 |
+
# Shade
|
136 |
+
################################################################################
|
137 |
+
|
138 |
+
color = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv,
|
139 |
+
view_pos, light_pos, light_power, mesh.material, min_roughness)
|
140 |
+
|
141 |
+
################################################################################
|
142 |
+
# Prepare output
|
143 |
+
################################################################################
|
144 |
+
|
145 |
+
# Scale back up to visibility resolution if using MSAA
|
146 |
+
if spp > 1 and msaa:
|
147 |
+
color = util.scale_img_nhwc(color, [full_res, full_res], mag='nearest', min='nearest')
|
148 |
+
|
149 |
+
# Return color & raster output for peeling
|
150 |
+
return color
|
151 |
+
|
152 |
+
|
153 |
+
# ==============================================================================================
|
154 |
+
# Render a depth peeled mesh (scene), some limitations:
|
155 |
+
# - Single mesh
|
156 |
+
# - Single light
|
157 |
+
# - Single material
|
158 |
+
# ==============================================================================================
|
159 |
+
def render_mesh(
|
160 |
+
ctx,
|
161 |
+
mesh,
|
162 |
+
mtx_in,
|
163 |
+
view_pos,
|
164 |
+
light_pos,
|
165 |
+
light_power,
|
166 |
+
resolution,
|
167 |
+
spp = 1,
|
168 |
+
num_layers = 1,
|
169 |
+
msaa = False,
|
170 |
+
background = None,
|
171 |
+
antialias = True,
|
172 |
+
min_roughness = 0.08,
|
173 |
+
return_rast_map = False,
|
174 |
+
):
|
175 |
+
assert not (return_rast_map and num_layers > 1)
|
176 |
+
|
177 |
+
def prepare_input_vector(x):
|
178 |
+
x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x
|
179 |
+
return x[:, None, None, :] if len(x.shape) == 2 else x
|
180 |
+
|
181 |
+
full_res = resolution*spp
|
182 |
+
|
183 |
+
# Convert numpy arrays to torch tensors
|
184 |
+
mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in
|
185 |
+
light_pos = prepare_input_vector(light_pos)
|
186 |
+
light_power = prepare_input_vector(light_power)
|
187 |
+
view_pos = prepare_input_vector(view_pos)
|
188 |
+
|
189 |
+
# clip space transform
|
190 |
+
v_pos_clip = ru.xfm_points(mesh.v_pos[None, ...], mtx_in)
|
191 |
+
|
192 |
+
# Render all layers front-to-back
|
193 |
+
layers = []
|
194 |
+
with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx.int(), [resolution*spp, resolution*spp]) as peeler:
|
195 |
+
for _ in range(num_layers):
|
196 |
+
rast, db = peeler.rasterize_next_layer()
|
197 |
+
layers += [(render_layer(rast, db, mesh, view_pos, light_pos, light_power, resolution, min_roughness, spp, msaa), rast)]
|
198 |
+
|
199 |
+
if return_rast_map:
|
200 |
+
return rast.detach()
|
201 |
+
|
202 |
+
# Clear to background layer
|
203 |
+
if background is not None:
|
204 |
+
assert background.shape[1] == resolution and background.shape[2] == resolution
|
205 |
+
if spp > 1:
|
206 |
+
background = util.scale_img_nhwc(background, [full_res, full_res], mag='nearest', min='nearest')
|
207 |
+
accum_col = background
|
208 |
+
else:
|
209 |
+
accum_col = torch.zeros(size=(1, full_res, full_res, 3), dtype=torch.float32, device='cuda')
|
210 |
+
|
211 |
+
# Composite BACK-TO-FRONT
|
212 |
+
for color, rast in reversed(layers):
|
213 |
+
alpha = (rast[..., -1:] > 0) * color[..., 3:4]
|
214 |
+
accum_col = torch.lerp(accum_col, color[..., 0:3], alpha)
|
215 |
+
if antialias:
|
216 |
+
accum_col = dr.antialias(accum_col.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int()) # TODO: need to support bfloat16
|
217 |
+
|
218 |
+
# Downscale to framebuffer resolution. Use avg pooling
|
219 |
+
out = util.avg_pool_nhwc(accum_col, spp) if spp > 1 else accum_col
|
220 |
+
|
221 |
+
return out
|
222 |
+
|
223 |
+
|
nvdiffmodeling/src/renderutils/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
from .ops import xfm_points, xfm_vectors, image_loss, prepare_shading_normal, lambert, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith
|
10 |
+
__all__ = ["xfm_vectors", "xfm_points", "image_loss", "prepare_shading_normal", "lambert", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]
|
nvdiffmodeling/src/renderutils/bsdf.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import math
|
10 |
+
import torch
|
11 |
+
|
12 |
+
NORMAL_THRESHOLD = 0.1
|
13 |
+
|
14 |
+
################################################################################
|
15 |
+
# Vector utility functions
|
16 |
+
################################################################################
|
17 |
+
|
18 |
+
def _dot(x, y):
|
19 |
+
return torch.sum(x*y, -1, keepdim=True)
|
20 |
+
|
21 |
+
def _reflect(x, n):
|
22 |
+
return 2*_dot(x, n)*n - x
|
23 |
+
|
24 |
+
def _safe_normalize(x):
|
25 |
+
return torch.nn.functional.normalize(x, dim = -1)
|
26 |
+
|
27 |
+
def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
|
28 |
+
# Swap normal direction for backfacing surfaces
|
29 |
+
if two_sided_shading:
|
30 |
+
smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm)
|
31 |
+
geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm)
|
32 |
+
|
33 |
+
t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)
|
34 |
+
return torch.lerp(geom_nrm, smooth_nrm, t)
|
35 |
+
|
36 |
+
|
37 |
+
def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
|
38 |
+
smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm))
|
39 |
+
if opengl:
|
40 |
+
shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
|
41 |
+
else:
|
42 |
+
shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
|
43 |
+
return _safe_normalize(shading_nrm)
|
44 |
+
|
45 |
+
def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
|
46 |
+
smooth_nrm = _safe_normalize(smooth_nrm)
|
47 |
+
smooth_tng = _safe_normalize(smooth_tng)
|
48 |
+
view_vec = _safe_normalize(view_pos - pos)
|
49 |
+
shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)
|
50 |
+
return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)
|
51 |
+
|
52 |
+
################################################################################
|
53 |
+
# Simple lambertian diffuse BSDF
|
54 |
+
################################################################################
|
55 |
+
|
56 |
+
def bsdf_lambert(nrm, wi):
|
57 |
+
return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi
|
58 |
+
|
59 |
+
################################################################################
|
60 |
+
# Phong specular, loosely based on mitsuba implementation
|
61 |
+
################################################################################
|
62 |
+
|
63 |
+
def bsdf_phong(nrm, wo, wi, N):
|
64 |
+
dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)
|
65 |
+
dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
|
66 |
+
return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi)
|
67 |
+
|
68 |
+
################################################################################
|
69 |
+
# PBR's implementation of GGX specular
|
70 |
+
################################################################################
|
71 |
+
|
72 |
+
specular_epsilon = 1e-4
|
73 |
+
|
74 |
+
def bsdf_fresnel_shlick(f0, f90, cosTheta):
|
75 |
+
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
76 |
+
return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0
|
77 |
+
|
78 |
+
def bsdf_ndf_ggx(alphaSqr, cosTheta):
|
79 |
+
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
80 |
+
d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1
|
81 |
+
return alphaSqr / (d * d * math.pi)
|
82 |
+
|
83 |
+
def bsdf_lambda_ggx(alphaSqr, cosTheta):
|
84 |
+
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
85 |
+
cosThetaSqr = _cosTheta * _cosTheta
|
86 |
+
tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr
|
87 |
+
res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0)
|
88 |
+
return res
|
89 |
+
|
90 |
+
def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
|
91 |
+
lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI)
|
92 |
+
lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO)
|
93 |
+
return 1 / (1 + lambdaI + lambdaO)
|
94 |
+
|
95 |
+
def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
|
96 |
+
_alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
|
97 |
+
alphaSqr = _alpha * _alpha
|
98 |
+
|
99 |
+
h = _safe_normalize(wo + wi)
|
100 |
+
woDotN = _dot(wo, nrm)
|
101 |
+
wiDotN = _dot(wi, nrm)
|
102 |
+
woDotH = _dot(wo, h)
|
103 |
+
nDotH = _dot(nrm, h)
|
104 |
+
|
105 |
+
D = bsdf_ndf_ggx(alphaSqr, nDotH)
|
106 |
+
G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN)
|
107 |
+
F = bsdf_fresnel_shlick(col, 1, woDotH)
|
108 |
+
|
109 |
+
w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon)
|
110 |
+
|
111 |
+
frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon)
|
112 |
+
return torch.where(frontfacing, w, torch.zeros_like(w))
|
113 |
+
|
114 |
+
def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08):
|
115 |
+
wo = _safe_normalize(view_pos - pos)
|
116 |
+
wi = _safe_normalize(light_pos - pos)
|
117 |
+
|
118 |
+
spec_str = arm[..., 0:1] # x component
|
119 |
+
roughness = arm[..., 1:2] # y component
|
120 |
+
metallic = arm[..., 2:3] # z component
|
121 |
+
ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
|
122 |
+
kd = kd * (1.0 - metallic)
|
123 |
+
|
124 |
+
diffuse = kd * bsdf_lambert(nrm, wi)
|
125 |
+
specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
|
126 |
+
return diffuse + specular
|
nvdiffmodeling/src/renderutils/c_src/bsdf.cu
ADDED
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include "common.h"
|
10 |
+
#include "bsdf.h"
|
11 |
+
|
12 |
+
#define SPECULAR_EPSILON 1e-4f
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
// Lambert functions
|
16 |
+
|
17 |
+
__device__ float fwdLambert(const vec3f nrm, const vec3f wi)
|
18 |
+
{
|
19 |
+
return max(dot(nrm, wi) / M_PI, 0.0f);
|
20 |
+
}
|
21 |
+
|
22 |
+
__device__ void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
|
23 |
+
{
|
24 |
+
if (dot(nrm, wi) > 0.0f)
|
25 |
+
bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
|
26 |
+
}
|
27 |
+
|
28 |
+
//------------------------------------------------------------------------
|
29 |
+
// Fresnel Schlick
|
30 |
+
|
31 |
+
__device__ vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
|
32 |
+
{
|
33 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
34 |
+
float scale = powf(1.0f - _cosTheta, 5.0f);
|
35 |
+
return f0 * (1.0f - scale) + f90 * scale;
|
36 |
+
}
|
37 |
+
|
38 |
+
__device__ void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
|
39 |
+
{
|
40 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
41 |
+
float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
|
42 |
+
d_f0 += d_out * (1.0 - scale);
|
43 |
+
d_f90 += d_out * scale;
|
44 |
+
if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
45 |
+
{
|
46 |
+
d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
|
47 |
+
}
|
48 |
+
}
|
49 |
+
|
50 |
+
//------------------------------------------------------------------------
|
51 |
+
// Ndf GGX
|
52 |
+
|
53 |
+
__device__ float fwdNdfGGX(const float alphaSqr, const float cosTheta)
|
54 |
+
{
|
55 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
56 |
+
float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
|
57 |
+
return alphaSqr / (d * d * M_PI);
|
58 |
+
}
|
59 |
+
|
60 |
+
__device__ void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
61 |
+
{
|
62 |
+
// Torch only back propagates if clamp doesn't trigger
|
63 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
64 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
65 |
+
d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
|
66 |
+
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
67 |
+
{
|
68 |
+
d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
|
69 |
+
}
|
70 |
+
}
|
71 |
+
|
72 |
+
//------------------------------------------------------------------------
|
73 |
+
// Lambda GGX
|
74 |
+
|
75 |
+
__device__ float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
|
76 |
+
{
|
77 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
78 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
79 |
+
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
80 |
+
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
81 |
+
return res;
|
82 |
+
}
|
83 |
+
|
84 |
+
__device__ void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
85 |
+
{
|
86 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
87 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
88 |
+
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
89 |
+
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
90 |
+
|
91 |
+
d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);
|
92 |
+
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
93 |
+
d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
|
94 |
+
}
|
95 |
+
|
96 |
+
//------------------------------------------------------------------------
|
97 |
+
// Masking GGX
|
98 |
+
|
99 |
+
__device__ float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
|
100 |
+
{
|
101 |
+
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
102 |
+
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
103 |
+
return 1.0f / (1.0f + lambdaI + lambdaO);
|
104 |
+
}
|
105 |
+
|
106 |
+
__device__ void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
|
107 |
+
{
|
108 |
+
// FWD eval
|
109 |
+
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
110 |
+
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
111 |
+
|
112 |
+
// BWD eval
|
113 |
+
float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
|
114 |
+
bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
|
115 |
+
bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
|
116 |
+
}
|
117 |
+
|
118 |
+
//------------------------------------------------------------------------
|
119 |
+
// GGX specular
|
120 |
+
|
121 |
+
__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
|
122 |
+
{
|
123 |
+
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
124 |
+
float alphaSqr = _alpha * _alpha;
|
125 |
+
|
126 |
+
vec3f h = safeNormalize(wo + wi);
|
127 |
+
float woDotN = dot(wo, nrm);
|
128 |
+
float wiDotN = dot(wi, nrm);
|
129 |
+
float woDotH = dot(wo, h);
|
130 |
+
float nDotH = dot(nrm, h);
|
131 |
+
|
132 |
+
float D = fwdNdfGGX(alphaSqr, nDotH);
|
133 |
+
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
134 |
+
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
135 |
+
vec3f w = F * D * G * 0.25 / woDotN;
|
136 |
+
|
137 |
+
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
138 |
+
return frontfacing ? w : 0.0f;
|
139 |
+
}
|
140 |
+
|
141 |
+
__device__ void bwdPbrSpecular(
|
142 |
+
const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
|
143 |
+
vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
|
144 |
+
{
|
145 |
+
///////////////////////////////////////////////////////////////////////
|
146 |
+
// FWD eval
|
147 |
+
|
148 |
+
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
149 |
+
float alphaSqr = _alpha * _alpha;
|
150 |
+
|
151 |
+
vec3f h = safeNormalize(wo + wi);
|
152 |
+
float woDotN = dot(wo, nrm);
|
153 |
+
float wiDotN = dot(wi, nrm);
|
154 |
+
float woDotH = dot(wo, h);
|
155 |
+
float nDotH = dot(nrm, h);
|
156 |
+
|
157 |
+
float D = fwdNdfGGX(alphaSqr, nDotH);
|
158 |
+
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
159 |
+
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
160 |
+
vec3f w = F * D * G * 0.25 / woDotN;
|
161 |
+
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
162 |
+
|
163 |
+
if (frontfacing)
|
164 |
+
{
|
165 |
+
///////////////////////////////////////////////////////////////////////
|
166 |
+
// BWD eval
|
167 |
+
|
168 |
+
vec3f d_F = d_out * D * G * 0.25f / woDotN;
|
169 |
+
float d_D = sum(d_out * F * G * 0.25f / woDotN);
|
170 |
+
float d_G = sum(d_out * F * D * 0.25f / woDotN);
|
171 |
+
|
172 |
+
float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));
|
173 |
+
|
174 |
+
vec3f d_f90(0);
|
175 |
+
float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
|
176 |
+
bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
|
177 |
+
bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
|
178 |
+
bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);
|
179 |
+
|
180 |
+
vec3f d_h(0);
|
181 |
+
bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
|
182 |
+
bwdDot(wo, h, d_wo, d_h, d_woDotH);
|
183 |
+
bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
|
184 |
+
bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
|
185 |
+
|
186 |
+
vec3f d_h_unnorm(0);
|
187 |
+
bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
|
188 |
+
d_wo += d_h_unnorm;
|
189 |
+
d_wi += d_h_unnorm;
|
190 |
+
|
191 |
+
if (alpha > min_roughness * min_roughness)
|
192 |
+
d_alpha += d_alphaSqr * 2 * alpha;
|
193 |
+
}
|
194 |
+
}
|
195 |
+
|
196 |
+
//------------------------------------------------------------------------
|
197 |
+
// Full PBR BSDF
|
198 |
+
|
199 |
+
__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness)
|
200 |
+
{
|
201 |
+
vec3f wo = safeNormalize(view_pos - pos);
|
202 |
+
vec3f wi = safeNormalize(light_pos - pos);
|
203 |
+
|
204 |
+
float alpha = arm.y * arm.y;
|
205 |
+
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
206 |
+
vec3f diff_col = kd * (1.0f - arm.z);
|
207 |
+
|
208 |
+
float lambert = fwdLambert(nrm, wi);
|
209 |
+
vec3f diffuse = diff_col * lambert;
|
210 |
+
vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);
|
211 |
+
|
212 |
+
return diffuse + specular;
|
213 |
+
}
|
214 |
+
|
215 |
+
__device__ void bwdPbrBSDF(
|
216 |
+
const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness,
|
217 |
+
vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
|
218 |
+
{
|
219 |
+
////////////////////////////////////////////////////////////////////////
|
220 |
+
// FWD
|
221 |
+
vec3f _wi = light_pos - pos;
|
222 |
+
vec3f _wo = view_pos - pos;
|
223 |
+
vec3f wi = safeNormalize(_wi);
|
224 |
+
vec3f wo = safeNormalize(_wo);
|
225 |
+
|
226 |
+
float alpha = arm.y * arm.y;
|
227 |
+
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
228 |
+
vec3f diff_col = kd * (1.0f - arm.z);
|
229 |
+
float lambert = fwdLambert(nrm, wi);
|
230 |
+
|
231 |
+
////////////////////////////////////////////////////////////////////////
|
232 |
+
// BWD
|
233 |
+
|
234 |
+
float d_alpha(0);
|
235 |
+
vec3f d_spec_col(0), d_wi(0), d_wo(0);
|
236 |
+
bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
237 |
+
|
238 |
+
float d_lambert = sum(diff_col * d_out);
|
239 |
+
bwdLambert(nrm, wi, d_nrm, d_wi, d_lambert);
|
240 |
+
|
241 |
+
// Backprop: diff_col = kd * (1.0f - arm.z)
|
242 |
+
vec3f d_diff_col = d_out * lambert;
|
243 |
+
d_kd += d_diff_col * (1.0f - arm.z);
|
244 |
+
d_arm.z -= sum(d_diff_col * kd);
|
245 |
+
|
246 |
+
// Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
|
247 |
+
d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
|
248 |
+
d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
|
249 |
+
d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));
|
250 |
+
|
251 |
+
// Backprop: alpha = arm.y * arm.y
|
252 |
+
d_arm.y += d_alpha * 2 * arm.y;
|
253 |
+
|
254 |
+
// Backprop: vec3f wi = safeNormalize(light_pos - pos);
|
255 |
+
vec3f d__wi(0);
|
256 |
+
bwdSafeNormalize(_wi, d__wi, d_wi);
|
257 |
+
d_light_pos += d__wi;
|
258 |
+
d_pos -= d__wi;
|
259 |
+
|
260 |
+
// Backprop: vec3f wo = safeNormalize(view_pos - pos);
|
261 |
+
vec3f d__wo(0);
|
262 |
+
bwdSafeNormalize(_wo, d__wo, d_wo);
|
263 |
+
d_view_pos += d__wo;
|
264 |
+
d_pos -= d__wo;
|
265 |
+
}
|
266 |
+
|
267 |
+
//------------------------------------------------------------------------
|
268 |
+
// Kernels
|
269 |
+
|
270 |
+
__global__ void LambertFwdKernel(LambertKernelParams p)
|
271 |
+
{
|
272 |
+
// Calculate pixel position.
|
273 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
274 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
275 |
+
unsigned int pz = blockIdx.z;
|
276 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
277 |
+
return;
|
278 |
+
|
279 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
280 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
281 |
+
|
282 |
+
float res = fwdLambert(nrm, wi);
|
283 |
+
|
284 |
+
p.out.store(px, py, pz, res);
|
285 |
+
}
|
286 |
+
|
287 |
+
__global__ void LambertBwdKernel(LambertKernelParams p)
|
288 |
+
{
|
289 |
+
// Calculate pixel position.
|
290 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
291 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
292 |
+
unsigned int pz = blockIdx.z;
|
293 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
294 |
+
return;
|
295 |
+
|
296 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
297 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
298 |
+
float d_out = p.out.fetch1(px, py, pz);
|
299 |
+
|
300 |
+
vec3f d_nrm(0), d_wi(0);
|
301 |
+
bwdLambert(nrm, wi, d_nrm, d_wi, d_out);
|
302 |
+
|
303 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
304 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
305 |
+
}
|
306 |
+
|
307 |
+
__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
|
308 |
+
{
|
309 |
+
// Calculate pixel position.
|
310 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
311 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
312 |
+
unsigned int pz = blockIdx.z;
|
313 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
314 |
+
return;
|
315 |
+
|
316 |
+
vec3f f0 = p.f0.fetch3(px, py, pz);
|
317 |
+
vec3f f90 = p.f90.fetch3(px, py, pz);
|
318 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
319 |
+
|
320 |
+
vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);
|
321 |
+
p.out.store(px, py, pz, res);
|
322 |
+
}
|
323 |
+
|
324 |
+
__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
|
325 |
+
{
|
326 |
+
// Calculate pixel position.
|
327 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
328 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
329 |
+
unsigned int pz = blockIdx.z;
|
330 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
331 |
+
return;
|
332 |
+
|
333 |
+
vec3f f0 = p.f0.fetch3(px, py, pz);
|
334 |
+
vec3f f90 = p.f90.fetch3(px, py, pz);
|
335 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
336 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
337 |
+
|
338 |
+
vec3f d_f0(0), d_f90(0);
|
339 |
+
float d_cosTheta(0);
|
340 |
+
bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);
|
341 |
+
|
342 |
+
p.f0.store_grad(px, py, pz, d_f0);
|
343 |
+
p.f90.store_grad(px, py, pz, d_f90);
|
344 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
345 |
+
}
|
346 |
+
|
347 |
+
__global__ void ndfGGXFwdKernel(NdfGGXParams p)
|
348 |
+
{
|
349 |
+
// Calculate pixel position.
|
350 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
351 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
352 |
+
unsigned int pz = blockIdx.z;
|
353 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
354 |
+
return;
|
355 |
+
|
356 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
357 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
358 |
+
float res = fwdNdfGGX(alphaSqr, cosTheta);
|
359 |
+
|
360 |
+
p.out.store(px, py, pz, res);
|
361 |
+
}
|
362 |
+
|
363 |
+
__global__ void ndfGGXBwdKernel(NdfGGXParams p)
|
364 |
+
{
|
365 |
+
// Calculate pixel position.
|
366 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
367 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
368 |
+
unsigned int pz = blockIdx.z;
|
369 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
370 |
+
return;
|
371 |
+
|
372 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
373 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
374 |
+
float d_out = p.out.fetch1(px, py, pz);
|
375 |
+
|
376 |
+
float d_alphaSqr(0), d_cosTheta(0);
|
377 |
+
bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
378 |
+
|
379 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
380 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
381 |
+
}
|
382 |
+
|
383 |
+
__global__ void lambdaGGXFwdKernel(NdfGGXParams p)
|
384 |
+
{
|
385 |
+
// Calculate pixel position.
|
386 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
387 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
388 |
+
unsigned int pz = blockIdx.z;
|
389 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
390 |
+
return;
|
391 |
+
|
392 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
393 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
394 |
+
float res = fwdLambdaGGX(alphaSqr, cosTheta);
|
395 |
+
|
396 |
+
p.out.store(px, py, pz, res);
|
397 |
+
}
|
398 |
+
|
399 |
+
__global__ void lambdaGGXBwdKernel(NdfGGXParams p)
|
400 |
+
{
|
401 |
+
// Calculate pixel position.
|
402 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
403 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
404 |
+
unsigned int pz = blockIdx.z;
|
405 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
406 |
+
return;
|
407 |
+
|
408 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
409 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
410 |
+
float d_out = p.out.fetch1(px, py, pz);
|
411 |
+
|
412 |
+
float d_alphaSqr(0), d_cosTheta(0);
|
413 |
+
bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
414 |
+
|
415 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
416 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
417 |
+
}
|
418 |
+
|
419 |
+
__global__ void maskingSmithFwdKernel(MaskingSmithParams p)
|
420 |
+
{
|
421 |
+
// Calculate pixel position.
|
422 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
423 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
424 |
+
unsigned int pz = blockIdx.z;
|
425 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
426 |
+
return;
|
427 |
+
|
428 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
429 |
+
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
430 |
+
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
431 |
+
float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);
|
432 |
+
|
433 |
+
p.out.store(px, py, pz, res);
|
434 |
+
}
|
435 |
+
|
436 |
+
__global__ void maskingSmithBwdKernel(MaskingSmithParams p)
|
437 |
+
{
|
438 |
+
// Calculate pixel position.
|
439 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
440 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
441 |
+
unsigned int pz = blockIdx.z;
|
442 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
443 |
+
return;
|
444 |
+
|
445 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
446 |
+
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
447 |
+
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
448 |
+
float d_out = p.out.fetch1(px, py, pz);
|
449 |
+
|
450 |
+
float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
|
451 |
+
bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);
|
452 |
+
|
453 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
454 |
+
p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
|
455 |
+
p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
|
456 |
+
}
|
457 |
+
|
458 |
+
__global__ void pbrSpecularFwdKernel(PbrSpecular p)
|
459 |
+
{
|
460 |
+
// Calculate pixel position.
|
461 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
462 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
463 |
+
unsigned int pz = blockIdx.z;
|
464 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
465 |
+
return;
|
466 |
+
|
467 |
+
vec3f col = p.col.fetch3(px, py, pz);
|
468 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
469 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
470 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
471 |
+
float alpha = p.alpha.fetch1(px, py, pz);
|
472 |
+
|
473 |
+
vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);
|
474 |
+
|
475 |
+
p.out.store(px, py, pz, res);
|
476 |
+
}
|
477 |
+
|
478 |
+
__global__ void pbrSpecularBwdKernel(PbrSpecular p)
|
479 |
+
{
|
480 |
+
// Calculate pixel position.
|
481 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
482 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
483 |
+
unsigned int pz = blockIdx.z;
|
484 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
485 |
+
return;
|
486 |
+
|
487 |
+
vec3f col = p.col.fetch3(px, py, pz);
|
488 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
489 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
490 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
491 |
+
float alpha = p.alpha.fetch1(px, py, pz);
|
492 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
493 |
+
|
494 |
+
float d_alpha(0);
|
495 |
+
vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
|
496 |
+
bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
497 |
+
|
498 |
+
p.col.store_grad(px, py, pz, d_col);
|
499 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
500 |
+
p.wo.store_grad(px, py, pz, d_wo);
|
501 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
502 |
+
p.alpha.store_grad(px, py, pz, d_alpha);
|
503 |
+
}
|
504 |
+
|
505 |
+
__global__ void pbrBSDFFwdKernel(PbrBSDF p)
|
506 |
+
{
|
507 |
+
// Calculate pixel position.
|
508 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
509 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
510 |
+
unsigned int pz = blockIdx.z;
|
511 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
512 |
+
return;
|
513 |
+
|
514 |
+
vec3f kd = p.kd.fetch3(px, py, pz);
|
515 |
+
vec3f arm = p.arm.fetch3(px, py, pz);
|
516 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
517 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
518 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
519 |
+
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
520 |
+
|
521 |
+
vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness);
|
522 |
+
|
523 |
+
p.out.store(px, py, pz, res);
|
524 |
+
}
|
525 |
+
__global__ void pbrBSDFBwdKernel(PbrBSDF p)
|
526 |
+
{
|
527 |
+
// Calculate pixel position.
|
528 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
529 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
530 |
+
unsigned int pz = blockIdx.z;
|
531 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
532 |
+
return;
|
533 |
+
|
534 |
+
vec3f kd = p.kd.fetch3(px, py, pz);
|
535 |
+
vec3f arm = p.arm.fetch3(px, py, pz);
|
536 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
537 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
538 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
539 |
+
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
540 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
541 |
+
|
542 |
+
vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
|
543 |
+
bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);
|
544 |
+
|
545 |
+
p.kd.store_grad(px, py, pz, d_kd);
|
546 |
+
p.arm.store_grad(px, py, pz, d_arm);
|
547 |
+
p.pos.store_grad(px, py, pz, d_pos);
|
548 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
549 |
+
p.view_pos.store_grad(px, py, pz, d_view_pos);
|
550 |
+
p.light_pos.store_grad(px, py, pz, d_light_pos);
|
551 |
+
}
|
nvdiffmodeling/src/renderutils/c_src/bsdf.h
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#pragma once
|
10 |
+
|
11 |
+
#include "common.h"
|
12 |
+
|
13 |
+
struct LambertKernelParams
|
14 |
+
{
|
15 |
+
Tensor nrm;
|
16 |
+
Tensor wi;
|
17 |
+
Tensor out;
|
18 |
+
dim3 gridSize;
|
19 |
+
};
|
20 |
+
|
21 |
+
struct FresnelShlickKernelParams
|
22 |
+
{
|
23 |
+
Tensor f0;
|
24 |
+
Tensor f90;
|
25 |
+
Tensor cosTheta;
|
26 |
+
Tensor out;
|
27 |
+
dim3 gridSize;
|
28 |
+
};
|
29 |
+
|
30 |
+
struct NdfGGXParams
|
31 |
+
{
|
32 |
+
Tensor alphaSqr;
|
33 |
+
Tensor cosTheta;
|
34 |
+
Tensor out;
|
35 |
+
dim3 gridSize;
|
36 |
+
};
|
37 |
+
|
38 |
+
struct MaskingSmithParams
|
39 |
+
{
|
40 |
+
Tensor alphaSqr;
|
41 |
+
Tensor cosThetaI;
|
42 |
+
Tensor cosThetaO;
|
43 |
+
Tensor out;
|
44 |
+
dim3 gridSize;
|
45 |
+
};
|
46 |
+
|
47 |
+
struct PbrSpecular
|
48 |
+
{
|
49 |
+
Tensor col;
|
50 |
+
Tensor nrm;
|
51 |
+
Tensor wo;
|
52 |
+
Tensor wi;
|
53 |
+
Tensor alpha;
|
54 |
+
Tensor out;
|
55 |
+
dim3 gridSize;
|
56 |
+
float min_roughness;
|
57 |
+
};
|
58 |
+
|
59 |
+
struct PbrBSDF
|
60 |
+
{
|
61 |
+
Tensor kd;
|
62 |
+
Tensor arm;
|
63 |
+
Tensor pos;
|
64 |
+
Tensor nrm;
|
65 |
+
Tensor view_pos;
|
66 |
+
Tensor light_pos;
|
67 |
+
Tensor out;
|
68 |
+
dim3 gridSize;
|
69 |
+
float min_roughness;
|
70 |
+
};
|
nvdiffmodeling/src/renderutils/c_src/common.cpp
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <cuda_runtime.h>
|
10 |
+
#include <algorithm>
|
11 |
+
|
12 |
+
//------------------------------------------------------------------------
|
13 |
+
// Block and grid size calculators for kernel launches.
|
14 |
+
|
15 |
+
dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)
|
16 |
+
{
|
17 |
+
int maxThreads = maxWidth * maxHeight;
|
18 |
+
if (maxThreads <= 1 || (dims.x * dims.y) <= 1)
|
19 |
+
return dim3(1, 1, 1); // Degenerate.
|
20 |
+
|
21 |
+
// Start from max size.
|
22 |
+
int bw = maxWidth;
|
23 |
+
int bh = maxHeight;
|
24 |
+
|
25 |
+
// Optimizations for weirdly sized buffers.
|
26 |
+
if (dims.x < bw)
|
27 |
+
{
|
28 |
+
// Decrease block width to smallest power of two that covers the buffer width.
|
29 |
+
while ((bw >> 1) >= dims.x)
|
30 |
+
bw >>= 1;
|
31 |
+
|
32 |
+
// Maximize height.
|
33 |
+
bh = maxThreads / bw;
|
34 |
+
if (bh > dims.y)
|
35 |
+
bh = dims.y;
|
36 |
+
}
|
37 |
+
else if (dims.y < bh)
|
38 |
+
{
|
39 |
+
// Halve height and double width until fits completely inside buffer vertically.
|
40 |
+
while (bh > dims.y)
|
41 |
+
{
|
42 |
+
bh >>= 1;
|
43 |
+
if (bw < dims.x)
|
44 |
+
bw <<= 1;
|
45 |
+
}
|
46 |
+
}
|
47 |
+
|
48 |
+
// Done.
|
49 |
+
return dim3(bw, bh, 1);
|
50 |
+
}
|
51 |
+
|
52 |
+
// returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync)
|
53 |
+
dim3 getWarpSize(dim3 blockSize)
|
54 |
+
{
|
55 |
+
return dim3(
|
56 |
+
std::min(blockSize.x, 32u),
|
57 |
+
std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)),
|
58 |
+
std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z))
|
59 |
+
);
|
60 |
+
}
|
61 |
+
|
62 |
+
dim3 getLaunchGridSize(dim3 blockSize, dim3 dims)
|
63 |
+
{
|
64 |
+
dim3 gridSize;
|
65 |
+
gridSize.x = (dims.x - 1) / blockSize.x + 1;
|
66 |
+
gridSize.y = (dims.y - 1) / blockSize.y + 1;
|
67 |
+
gridSize.z = (dims.z - 1) / blockSize.z + 1;
|
68 |
+
return gridSize;
|
69 |
+
}
|
70 |
+
|
71 |
+
//------------------------------------------------------------------------
|
nvdiffmodeling/src/renderutils/c_src/common.h
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#pragma once
|
10 |
+
#include <cuda.h>
|
11 |
+
#include <stdint.h>
|
12 |
+
|
13 |
+
#include "vec3f.h"
|
14 |
+
#include "vec4f.h"
|
15 |
+
#include "tensor.h"
|
16 |
+
|
17 |
+
dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims);
|
18 |
+
dim3 getLaunchGridSize(dim3 blockSize, dim3 dims);
|
19 |
+
|
20 |
+
#ifdef __CUDACC__
|
21 |
+
|
22 |
+
#ifdef _MSC_VER
|
23 |
+
#define M_PI 3.14159265358979323846f
|
24 |
+
#endif
|
25 |
+
|
26 |
+
__host__ __device__ static inline dim3 getWarpSize(dim3 blockSize)
|
27 |
+
{
|
28 |
+
return dim3(
|
29 |
+
min(blockSize.x, 32u),
|
30 |
+
min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)),
|
31 |
+
min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z))
|
32 |
+
);
|
33 |
+
}
|
34 |
+
|
35 |
+
__device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); }
|
36 |
+
#else
|
37 |
+
dim3 getWarpSize(dim3 blockSize);
|
38 |
+
#endif
|
nvdiffmodeling/src/renderutils/c_src/loss.cu
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <cuda.h>
|
10 |
+
|
11 |
+
#include "common.h"
|
12 |
+
#include "loss.h"
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
// Utils
|
16 |
+
|
17 |
+
__device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; }
|
18 |
+
|
19 |
+
__device__ float warpSum(float val) {
|
20 |
+
for (int i = 1; i < 32; i *= 2)
|
21 |
+
val += __shfl_xor_sync(0xFFFFFFFF, val, i);
|
22 |
+
return val;
|
23 |
+
}
|
24 |
+
|
25 |
+
//------------------------------------------------------------------------
|
26 |
+
// Tonemapping
|
27 |
+
|
28 |
+
__device__ inline float fwdSRGB(float x)
|
29 |
+
{
|
30 |
+
return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f);
|
31 |
+
}
|
32 |
+
|
33 |
+
__device__ inline void bwdSRGB(float x, float &d_x, float d_out)
|
34 |
+
{
|
35 |
+
if (x > 0.0031308f)
|
36 |
+
d_x += d_out * 0.439583f / powf(x, 0.583333f);
|
37 |
+
else if (x > 0.0f)
|
38 |
+
d_x += d_out * 12.92f;
|
39 |
+
}
|
40 |
+
|
41 |
+
__device__ inline vec3f fwdTonemapLogSRGB(vec3f x)
|
42 |
+
{
|
43 |
+
return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f)));
|
44 |
+
}
|
45 |
+
|
46 |
+
__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out)
|
47 |
+
{
|
48 |
+
if (x.x > 0.0f && x.x < 65535.0f)
|
49 |
+
{
|
50 |
+
bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x);
|
51 |
+
d_x.x *= 1 / (x.x + 1.0f);
|
52 |
+
}
|
53 |
+
if (x.y > 0.0f && x.y < 65535.0f)
|
54 |
+
{
|
55 |
+
bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y);
|
56 |
+
d_x.y *= 1 / (x.y + 1.0f);
|
57 |
+
}
|
58 |
+
if (x.z > 0.0f && x.z < 65535.0f)
|
59 |
+
{
|
60 |
+
bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z);
|
61 |
+
d_x.z *= 1 / (x.z + 1.0f);
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f)
|
66 |
+
{
|
67 |
+
return (img - target) * (img - target) / (img * img + target * target + eps);
|
68 |
+
}
|
69 |
+
|
70 |
+
__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f)
|
71 |
+
{
|
72 |
+
float denom = (target * target + img * img + eps);
|
73 |
+
d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom);
|
74 |
+
d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom);
|
75 |
+
}
|
76 |
+
|
77 |
+
__device__ inline float fwdSMAPE(float img, float target, float eps=0.01f)
|
78 |
+
{
|
79 |
+
return abs(img - target) / (img + target + eps);
|
80 |
+
}
|
81 |
+
|
82 |
+
__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f)
|
83 |
+
{
|
84 |
+
float denom = (target + img + eps);
|
85 |
+
d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom);
|
86 |
+
d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom);
|
87 |
+
}
|
88 |
+
|
89 |
+
//------------------------------------------------------------------------
|
90 |
+
// Kernels
|
91 |
+
|
92 |
+
__global__ void imgLossFwdKernel(LossKernelParams p)
|
93 |
+
{
|
94 |
+
// Calculate pixel position.
|
95 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
96 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
97 |
+
unsigned int pz = blockIdx.z;
|
98 |
+
|
99 |
+
float floss = 0.0f;
|
100 |
+
if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z)
|
101 |
+
{
|
102 |
+
vec3f img = p.img.fetch3(px, py, pz);
|
103 |
+
vec3f target = p.target.fetch3(px, py, pz);
|
104 |
+
|
105 |
+
img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f));
|
106 |
+
target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f));
|
107 |
+
|
108 |
+
if (p.tonemapper == TONEMAPPER_LOG_SRGB)
|
109 |
+
{
|
110 |
+
img = fwdTonemapLogSRGB(img);
|
111 |
+
target = fwdTonemapLogSRGB(target);
|
112 |
+
}
|
113 |
+
|
114 |
+
vec3f vloss(0);
|
115 |
+
if (p.loss == LOSS_MSE)
|
116 |
+
vloss = (img - target) * (img - target);
|
117 |
+
else if (p.loss == LOSS_RELMSE)
|
118 |
+
vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z));
|
119 |
+
else if (p.loss == LOSS_SMAPE)
|
120 |
+
vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z));
|
121 |
+
else
|
122 |
+
vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z));
|
123 |
+
|
124 |
+
floss = sum(vloss) / 3.0f;
|
125 |
+
}
|
126 |
+
|
127 |
+
floss = warpSum(floss);
|
128 |
+
|
129 |
+
dim3 warpSize = getWarpSize(blockDim);
|
130 |
+
if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0)
|
131 |
+
p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss);
|
132 |
+
}
|
133 |
+
|
134 |
+
__global__ void imgLossBwdKernel(LossKernelParams p)
|
135 |
+
{
|
136 |
+
// Calculate pixel position.
|
137 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
138 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
139 |
+
unsigned int pz = blockIdx.z;
|
140 |
+
|
141 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
142 |
+
return;
|
143 |
+
|
144 |
+
dim3 warpSize = getWarpSize(blockDim);
|
145 |
+
|
146 |
+
vec3f _img = p.img.fetch3(px, py, pz);
|
147 |
+
vec3f _target = p.target.fetch3(px, py, pz);
|
148 |
+
float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z);
|
149 |
+
|
150 |
+
/////////////////////////////////////////////////////////////////////
|
151 |
+
// FWD
|
152 |
+
|
153 |
+
vec3f img = _img, target = _target;
|
154 |
+
if (p.tonemapper == TONEMAPPER_LOG_SRGB)
|
155 |
+
{
|
156 |
+
img = fwdTonemapLogSRGB(img);
|
157 |
+
target = fwdTonemapLogSRGB(target);
|
158 |
+
}
|
159 |
+
|
160 |
+
/////////////////////////////////////////////////////////////////////
|
161 |
+
// BWD
|
162 |
+
|
163 |
+
vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f;
|
164 |
+
|
165 |
+
vec3f d_img(0), d_target(0);
|
166 |
+
if (p.loss == LOSS_MSE)
|
167 |
+
{
|
168 |
+
d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.x * 2 * (img.z - target.z));
|
169 |
+
d_target = -d_img;
|
170 |
+
}
|
171 |
+
else if (p.loss == LOSS_RELMSE)
|
172 |
+
{
|
173 |
+
bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
|
174 |
+
bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
|
175 |
+
bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
|
176 |
+
}
|
177 |
+
else if (p.loss == LOSS_SMAPE)
|
178 |
+
{
|
179 |
+
bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
|
180 |
+
bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
|
181 |
+
bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
|
182 |
+
}
|
183 |
+
else
|
184 |
+
{
|
185 |
+
d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z));
|
186 |
+
d_target = -d_img;
|
187 |
+
}
|
188 |
+
|
189 |
+
|
190 |
+
if (p.tonemapper == TONEMAPPER_LOG_SRGB)
|
191 |
+
{
|
192 |
+
vec3f d__img(0), d__target(0);
|
193 |
+
bwdTonemapLogSRGB(_img, d__img, d_img);
|
194 |
+
bwdTonemapLogSRGB(_target, d__target, d_target);
|
195 |
+
d_img = d__img; d_target = d__target;
|
196 |
+
}
|
197 |
+
|
198 |
+
if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0;
|
199 |
+
if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0;
|
200 |
+
if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0;
|
201 |
+
if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0;
|
202 |
+
if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0;
|
203 |
+
if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0;
|
204 |
+
|
205 |
+
p.img.store_grad(px, py, pz, d_img);
|
206 |
+
p.target.store_grad(px, py, pz, d_target);
|
207 |
+
}
|
nvdiffmodeling/src/renderutils/c_src/loss.h
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#pragma once
|
10 |
+
|
11 |
+
#include "common.h"
|
12 |
+
|
13 |
+
enum TonemapperType
|
14 |
+
{
|
15 |
+
TONEMAPPER_NONE = 0,
|
16 |
+
TONEMAPPER_LOG_SRGB = 1
|
17 |
+
};
|
18 |
+
|
19 |
+
enum LossType
|
20 |
+
{
|
21 |
+
LOSS_L1 = 0,
|
22 |
+
LOSS_MSE = 1,
|
23 |
+
LOSS_RELMSE = 2,
|
24 |
+
LOSS_SMAPE = 3
|
25 |
+
};
|
26 |
+
|
27 |
+
struct LossKernelParams
|
28 |
+
{
|
29 |
+
Tensor img;
|
30 |
+
Tensor target;
|
31 |
+
Tensor out;
|
32 |
+
dim3 gridSize;
|
33 |
+
TonemapperType tonemapper;
|
34 |
+
LossType loss;
|
35 |
+
};
|
nvdiffmodeling/src/renderutils/c_src/mesh.cu
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <cuda.h>
|
10 |
+
|
11 |
+
#include "common.h"
|
12 |
+
#include "mesh.h"
|
13 |
+
|
14 |
+
|
15 |
+
//------------------------------------------------------------------------
|
16 |
+
// Kernels
|
17 |
+
|
18 |
+
__global__ void xfmPointsFwdKernel(XfmKernelParams p)
|
19 |
+
{
|
20 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
21 |
+
unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;
|
22 |
+
|
23 |
+
__shared__ float mtx[4][4];
|
24 |
+
if (threadIdx.x < 16)
|
25 |
+
mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
|
26 |
+
__syncthreads();
|
27 |
+
|
28 |
+
if (px >= p.gridSize.x)
|
29 |
+
return;
|
30 |
+
|
31 |
+
vec3f pos(
|
32 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
|
33 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
|
34 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
|
35 |
+
);
|
36 |
+
|
37 |
+
if (p.isPoints)
|
38 |
+
{
|
39 |
+
p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]);
|
40 |
+
p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]);
|
41 |
+
p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]);
|
42 |
+
p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]);
|
43 |
+
}
|
44 |
+
else
|
45 |
+
{
|
46 |
+
p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]);
|
47 |
+
p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]);
|
48 |
+
p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]);
|
49 |
+
}
|
50 |
+
}
|
51 |
+
|
52 |
+
__global__ void xfmPointsBwdKernel(XfmKernelParams p)
|
53 |
+
{
|
54 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
55 |
+
unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;
|
56 |
+
|
57 |
+
__shared__ float mtx[4][4];
|
58 |
+
if (threadIdx.x < 16)
|
59 |
+
mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
|
60 |
+
__syncthreads();
|
61 |
+
|
62 |
+
if (px >= p.gridSize.x)
|
63 |
+
return;
|
64 |
+
|
65 |
+
vec3f pos(
|
66 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
|
67 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
|
68 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
|
69 |
+
);
|
70 |
+
|
71 |
+
vec4f d_out(
|
72 |
+
p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)),
|
73 |
+
p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)),
|
74 |
+
p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)),
|
75 |
+
p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0))
|
76 |
+
);
|
77 |
+
|
78 |
+
if (p.isPoints)
|
79 |
+
{
|
80 |
+
p.points.store_grad(p.points._nhwcIndex(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]);
|
81 |
+
p.points.store_grad(p.points._nhwcIndex(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]);
|
82 |
+
p.points.store_grad(p.points._nhwcIndex(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]);
|
83 |
+
}
|
84 |
+
else
|
85 |
+
{
|
86 |
+
p.points.store_grad(p.points._nhwcIndex(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]);
|
87 |
+
p.points.store_grad(p.points._nhwcIndex(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]);
|
88 |
+
p.points.store_grad(p.points._nhwcIndex(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]);
|
89 |
+
}
|
90 |
+
}
|
nvdiffmodeling/src/renderutils/c_src/mesh.h
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#pragma once
|
10 |
+
|
11 |
+
#include "common.h"
|
12 |
+
|
13 |
+
struct XfmKernelParams
|
14 |
+
{
|
15 |
+
bool isPoints;
|
16 |
+
Tensor points;
|
17 |
+
Tensor matrix;
|
18 |
+
Tensor out;
|
19 |
+
dim3 gridSize;
|
20 |
+
};
|
nvdiffmodeling/src/renderutils/c_src/normal.cu
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include "common.h"
|
10 |
+
#include "normal.h"
|
11 |
+
|
12 |
+
#define NORMAL_THRESHOLD 0.1f
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
// Perturb shading normal by tangent frame
|
16 |
+
|
17 |
+
__device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl)
|
18 |
+
{
|
19 |
+
vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
|
20 |
+
vec3f smooth_bitng = safeNormalize(_smooth_bitng);
|
21 |
+
vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
|
22 |
+
return safeNormalize(_shading_nrm);
|
23 |
+
}
|
24 |
+
|
25 |
+
__device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl)
|
26 |
+
{
|
27 |
+
////////////////////////////////////////////////////////////////////////
|
28 |
+
// FWD
|
29 |
+
vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
|
30 |
+
vec3f smooth_bitng = safeNormalize(_smooth_bitng);
|
31 |
+
vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
|
32 |
+
|
33 |
+
////////////////////////////////////////////////////////////////////////
|
34 |
+
// BWD
|
35 |
+
vec3f d_shading_nrm(0);
|
36 |
+
bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out);
|
37 |
+
|
38 |
+
vec3f d_smooth_bitng(0);
|
39 |
+
|
40 |
+
if (perturbed_nrm.z > 0.0f)
|
41 |
+
{
|
42 |
+
d_smooth_nrm += d_shading_nrm * perturbed_nrm.z;
|
43 |
+
d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm);
|
44 |
+
}
|
45 |
+
|
46 |
+
d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y;
|
47 |
+
d_perturbed_nrm.y += (opengl ? -1 : 1) * sum(d_shading_nrm * smooth_bitng);
|
48 |
+
|
49 |
+
d_smooth_tng += d_shading_nrm * perturbed_nrm.x;
|
50 |
+
d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng);
|
51 |
+
|
52 |
+
vec3f d__smooth_bitng(0);
|
53 |
+
bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng);
|
54 |
+
|
55 |
+
bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng);
|
56 |
+
}
|
57 |
+
|
58 |
+
//------------------------------------------------------------------------
|
59 |
+
#define bent_nrm_eps 0.001f
|
60 |
+
|
61 |
+
__device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm)
|
62 |
+
{
|
63 |
+
float dp = dot(view_vec, smooth_nrm);
|
64 |
+
float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
|
65 |
+
return geom_nrm * (1.0f - t) + smooth_nrm * t;
|
66 |
+
}
|
67 |
+
|
68 |
+
__device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out)
|
69 |
+
{
|
70 |
+
////////////////////////////////////////////////////////////////////////
|
71 |
+
// FWD
|
72 |
+
float dp = dot(view_vec, smooth_nrm);
|
73 |
+
float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
|
74 |
+
|
75 |
+
////////////////////////////////////////////////////////////////////////
|
76 |
+
// BWD
|
77 |
+
if (dp > NORMAL_THRESHOLD)
|
78 |
+
d_smooth_nrm += d_out;
|
79 |
+
else
|
80 |
+
{
|
81 |
+
// geom_nrm * (1.0f - t) + smooth_nrm * t;
|
82 |
+
d_geom_nrm += d_out * (1.0f - t);
|
83 |
+
d_smooth_nrm += d_out * t;
|
84 |
+
float d_t = sum(d_out * (smooth_nrm - geom_nrm));
|
85 |
+
|
86 |
+
float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD;
|
87 |
+
|
88 |
+
bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp);
|
89 |
+
}
|
90 |
+
}
|
91 |
+
|
92 |
+
//------------------------------------------------------------------------
|
93 |
+
// Kernels
|
94 |
+
|
95 |
+
__global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p)
|
96 |
+
{
|
97 |
+
// Calculate pixel position.
|
98 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
99 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
100 |
+
unsigned int pz = blockIdx.z;
|
101 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
102 |
+
return;
|
103 |
+
|
104 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
105 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
106 |
+
vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
|
107 |
+
vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
|
108 |
+
vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
|
109 |
+
vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
|
110 |
+
|
111 |
+
vec3f smooth_nrm = safeNormalize(_smooth_nrm);
|
112 |
+
vec3f smooth_tng = safeNormalize(_smooth_tng);
|
113 |
+
vec3f view_vec = safeNormalize(view_pos - pos);
|
114 |
+
vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
|
115 |
+
|
116 |
+
vec3f res;
|
117 |
+
if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
|
118 |
+
res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm);
|
119 |
+
else
|
120 |
+
res = fwdBendNormal(view_vec, shading_nrm, geom_nrm);
|
121 |
+
|
122 |
+
p.out.store(px, py, pz, res);
|
123 |
+
}
|
124 |
+
|
125 |
+
__global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p)
|
126 |
+
{
|
127 |
+
// Calculate pixel position.
|
128 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
129 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
130 |
+
unsigned int pz = blockIdx.z;
|
131 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
132 |
+
return;
|
133 |
+
|
134 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
135 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
136 |
+
vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
|
137 |
+
vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
|
138 |
+
vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
|
139 |
+
vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
|
140 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
141 |
+
|
142 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
143 |
+
// FWD
|
144 |
+
|
145 |
+
vec3f smooth_nrm = safeNormalize(_smooth_nrm);
|
146 |
+
vec3f smooth_tng = safeNormalize(_smooth_tng);
|
147 |
+
vec3f _view_vec = view_pos - pos;
|
148 |
+
vec3f view_vec = safeNormalize(view_pos - pos);
|
149 |
+
|
150 |
+
vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
|
151 |
+
|
152 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
153 |
+
// BWD
|
154 |
+
|
155 |
+
vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0);
|
156 |
+
if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
|
157 |
+
{
|
158 |
+
bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
|
159 |
+
d_shading_nrm = -d_shading_nrm;
|
160 |
+
d_geom_nrm = -d_geom_nrm;
|
161 |
+
}
|
162 |
+
else
|
163 |
+
bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
|
164 |
+
|
165 |
+
vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0);
|
166 |
+
bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl);
|
167 |
+
|
168 |
+
vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0);
|
169 |
+
bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec);
|
170 |
+
bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm);
|
171 |
+
bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng);
|
172 |
+
|
173 |
+
p.pos.store_grad(px, py, pz, -d__view_vec);
|
174 |
+
p.view_pos.store_grad(px, py, pz, d__view_vec);
|
175 |
+
p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm);
|
176 |
+
p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm);
|
177 |
+
p.smooth_tng.store_grad(px, py, pz, d__smooth_tng);
|
178 |
+
p.geom_nrm.store_grad(px, py, pz, d_geom_nrm);
|
179 |
+
}
|
nvdiffmodeling/src/renderutils/c_src/normal.h
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#pragma once
|
10 |
+
|
11 |
+
#include "common.h"
|
12 |
+
|
13 |
+
struct PrepareShadingNormalKernelParams
|
14 |
+
{
|
15 |
+
Tensor pos;
|
16 |
+
Tensor view_pos;
|
17 |
+
Tensor perturbed_nrm;
|
18 |
+
Tensor smooth_nrm;
|
19 |
+
Tensor smooth_tng;
|
20 |
+
Tensor geom_nrm;
|
21 |
+
Tensor out;
|
22 |
+
dim3 gridSize;
|
23 |
+
bool two_sided_shading, opengl;
|
24 |
+
};
|
nvdiffmodeling/src/renderutils/c_src/tensor.h
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#pragma once
|
10 |
+
#if defined(__CUDACC__) && defined(BFLOAT16)
|
11 |
+
#include <cuda_bf16.h> // bfloat16 is float32 compatible with less mantissa bits
|
12 |
+
#endif
|
13 |
+
|
14 |
+
//---------------------------------------------------------------------------------
|
15 |
+
// CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16
|
16 |
+
|
17 |
+
struct Tensor
|
18 |
+
{
|
19 |
+
void* val;
|
20 |
+
void* d_val;
|
21 |
+
int dims[4];
|
22 |
+
int strides[4];
|
23 |
+
bool fp16;
|
24 |
+
Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {}
|
25 |
+
|
26 |
+
#ifdef __CUDACC__
|
27 |
+
// Helpers to index and read/write a single element
|
28 |
+
__device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; }
|
29 |
+
__device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); }
|
30 |
+
__device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * dims[1] + h) * dims[2] + w) * dims[3] + c; }
|
31 |
+
#ifdef BFLOAT16
|
32 |
+
__device__ inline float fetch(unsigned int idx) const { return fp16 ? __bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; }
|
33 |
+
__device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; }
|
34 |
+
__device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; }
|
35 |
+
#else
|
36 |
+
__device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; }
|
37 |
+
__device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; }
|
38 |
+
__device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; }
|
39 |
+
#endif
|
40 |
+
|
41 |
+
//////////////////////////////////////////////////////////////////////////////////////////
|
42 |
+
// Fetch, use broadcasting for tensor dimensions of size 1
|
43 |
+
__device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const
|
44 |
+
{
|
45 |
+
return fetch(nhwcIndex(z, y, x, 0));
|
46 |
+
}
|
47 |
+
|
48 |
+
__device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const
|
49 |
+
{
|
50 |
+
return vec3f(
|
51 |
+
fetch(nhwcIndex(z, y, x, 0)),
|
52 |
+
fetch(nhwcIndex(z, y, x, 1)),
|
53 |
+
fetch(nhwcIndex(z, y, x, 2))
|
54 |
+
);
|
55 |
+
}
|
56 |
+
|
57 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
58 |
+
// Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
|
59 |
+
__device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val)
|
60 |
+
{
|
61 |
+
store(_nhwcIndex(z, y, x, 0), _val);
|
62 |
+
}
|
63 |
+
|
64 |
+
__device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
|
65 |
+
{
|
66 |
+
store(_nhwcIndex(z, y, x, 0), _val.x);
|
67 |
+
store(_nhwcIndex(z, y, x, 1), _val.y);
|
68 |
+
store(_nhwcIndex(z, y, x, 2), _val.z);
|
69 |
+
}
|
70 |
+
|
71 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
72 |
+
// Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
|
73 |
+
__device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val)
|
74 |
+
{
|
75 |
+
store_grad(nhwcIndexContinuous(z, y, x, 0), _val);
|
76 |
+
}
|
77 |
+
|
78 |
+
__device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
|
79 |
+
{
|
80 |
+
store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x);
|
81 |
+
store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y);
|
82 |
+
store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z);
|
83 |
+
}
|
84 |
+
#endif
|
85 |
+
|
86 |
+
};
|
nvdiffmodeling/src/renderutils/c_src/torch_bindings.cpp
ADDED
@@ -0,0 +1,793 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <torch/extension.h>
|
10 |
+
#include <ATen/cuda/CUDAContext.h>
|
11 |
+
#include <ATen/cuda/CUDAUtils.h>
|
12 |
+
#include <algorithm>
|
13 |
+
#include <string>
|
14 |
+
|
15 |
+
#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) { cudaError_t err = CUDA_CALL; AT_CUDA_CHECK(cudaGetLastError()); }
|
16 |
+
#define NVDR_CHECK_GL_ERROR(GL_CALL) { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, "OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]"); }
|
17 |
+
#define CHECK_TENSOR(X, DIMS, CHANNELS) \
|
18 |
+
TORCH_CHECK(X.is_cuda(), #X " must be a cuda tensor") \
|
19 |
+
TORCH_CHECK(X.scalar_type() == torch::kFloat || X.scalar_type() == torch::kBFloat16, #X " must be fp32 or bf16") \
|
20 |
+
TORCH_CHECK(X.dim() == DIMS, #X " must have " #DIMS " dimensions") \
|
21 |
+
TORCH_CHECK(X.size(DIMS - 1) == CHANNELS, #X " must have " #CHANNELS " channels")
|
22 |
+
|
23 |
+
#include "common.h"
|
24 |
+
#include "loss.h"
|
25 |
+
#include "normal.h"
|
26 |
+
#include "bsdf.h"
|
27 |
+
#include "mesh.h"
|
28 |
+
|
29 |
+
#define BLOCK_X 8
|
30 |
+
#define BLOCK_Y 8
|
31 |
+
|
32 |
+
//------------------------------------------------------------------------
|
33 |
+
// mesh.cu
|
34 |
+
|
35 |
+
void xfmPointsFwdKernel(XfmKernelParams p);
|
36 |
+
void xfmPointsBwdKernel(XfmKernelParams p);
|
37 |
+
|
38 |
+
//------------------------------------------------------------------------
|
39 |
+
// loss.cu
|
40 |
+
|
41 |
+
void imgLossFwdKernel(LossKernelParams p);
|
42 |
+
void imgLossBwdKernel(LossKernelParams p);
|
43 |
+
|
44 |
+
//------------------------------------------------------------------------
|
45 |
+
// normal.cu
|
46 |
+
|
47 |
+
void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p);
|
48 |
+
void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p);
|
49 |
+
|
50 |
+
//------------------------------------------------------------------------
|
51 |
+
// bsdf.cu
|
52 |
+
|
53 |
+
void LambertFwdKernel(LambertKernelParams p);
|
54 |
+
void LambertBwdKernel(LambertKernelParams p);
|
55 |
+
|
56 |
+
void FresnelShlickFwdKernel(FresnelShlickKernelParams p);
|
57 |
+
void FresnelShlickBwdKernel(FresnelShlickKernelParams p);
|
58 |
+
|
59 |
+
void ndfGGXFwdKernel(NdfGGXParams p);
|
60 |
+
void ndfGGXBwdKernel(NdfGGXParams p);
|
61 |
+
|
62 |
+
void lambdaGGXFwdKernel(NdfGGXParams p);
|
63 |
+
void lambdaGGXBwdKernel(NdfGGXParams p);
|
64 |
+
|
65 |
+
void maskingSmithFwdKernel(MaskingSmithParams p);
|
66 |
+
void maskingSmithBwdKernel(MaskingSmithParams p);
|
67 |
+
|
68 |
+
void pbrSpecularFwdKernel(PbrSpecular p);
|
69 |
+
void pbrSpecularBwdKernel(PbrSpecular p);
|
70 |
+
|
71 |
+
void pbrBSDFFwdKernel(PbrBSDF p);
|
72 |
+
void pbrBSDFBwdKernel(PbrBSDF p);
|
73 |
+
|
74 |
+
//------------------------------------------------------------------------
|
75 |
+
// Tensor helpers
|
76 |
+
|
77 |
+
void update_grid(dim3 &gridSize, torch::Tensor x)
|
78 |
+
{
|
79 |
+
gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2));
|
80 |
+
gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1));
|
81 |
+
gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0));
|
82 |
+
}
|
83 |
+
|
84 |
+
template<typename... Ts>
|
85 |
+
void update_grid(dim3& gridSize, torch::Tensor x, Ts&&... vs)
|
86 |
+
{
|
87 |
+
gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2));
|
88 |
+
gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1));
|
89 |
+
gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0));
|
90 |
+
update_grid(gridSize, std::forward<Ts>(vs)...);
|
91 |
+
}
|
92 |
+
|
93 |
+
Tensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* grad = nullptr)
|
94 |
+
{
|
95 |
+
Tensor res;
|
96 |
+
for (int i = 0; i < val.dim(); ++i)
|
97 |
+
{
|
98 |
+
res.dims[i] = val.size(i);
|
99 |
+
res.strides[i] = val.stride(i);
|
100 |
+
}
|
101 |
+
|
102 |
+
res.fp16 = val.scalar_type() == torch::kBFloat16;
|
103 |
+
res.val = res.fp16 ? (void*)val.data_ptr<torch::BFloat16>() : (void*)val.data_ptr<float>();
|
104 |
+
res.d_val = nullptr;
|
105 |
+
if (grad != nullptr)
|
106 |
+
{
|
107 |
+
if (val.dim() == 4)
|
108 |
+
*grad = torch::empty({ outDims.z, outDims.y, outDims.x, val.size(3) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA));
|
109 |
+
else // 3
|
110 |
+
*grad = torch::empty({ outDims.z, outDims.x, val.size(2) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA));
|
111 |
+
|
112 |
+
res.d_val = res.fp16 ? (void*)grad->data_ptr<torch::BFloat16>() : (void*)grad->data_ptr<float>();
|
113 |
+
}
|
114 |
+
return res;
|
115 |
+
}
|
116 |
+
|
117 |
+
//------------------------------------------------------------------------
|
118 |
+
// prepare_shading_normal
|
119 |
+
|
120 |
+
torch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, bool two_sided_shading, bool opengl, bool fp16)
|
121 |
+
{
|
122 |
+
CHECK_TENSOR(pos, 4, 3);
|
123 |
+
CHECK_TENSOR(view_pos, 4, 3);
|
124 |
+
CHECK_TENSOR(perturbed_nrm, 4, 3);
|
125 |
+
CHECK_TENSOR(smooth_nrm, 4, 3);
|
126 |
+
CHECK_TENSOR(smooth_tng, 4, 3);
|
127 |
+
CHECK_TENSOR(geom_nrm, 4, 3);
|
128 |
+
|
129 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
130 |
+
|
131 |
+
// Extract input parameters.
|
132 |
+
PrepareShadingNormalKernelParams p;
|
133 |
+
p.two_sided_shading = two_sided_shading;
|
134 |
+
p.opengl = opengl;
|
135 |
+
p.out.fp16 = fp16;
|
136 |
+
update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm);
|
137 |
+
|
138 |
+
// Allocate output tensors.
|
139 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
140 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
|
141 |
+
|
142 |
+
// Choose launch parameters.
|
143 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
144 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
145 |
+
|
146 |
+
// Setup tensors
|
147 |
+
p.pos = make_cuda_tensor(pos, p.gridSize);
|
148 |
+
p.view_pos = make_cuda_tensor(view_pos, p.gridSize);
|
149 |
+
p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize);
|
150 |
+
p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize);
|
151 |
+
p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize);
|
152 |
+
p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize);
|
153 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
154 |
+
|
155 |
+
// Launch CUDA kernel.
|
156 |
+
void* args[] = { &p };
|
157 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalFwdKernel, gridSize, blockSize, args, 0, stream));
|
158 |
+
|
159 |
+
return out;
|
160 |
+
}
|
161 |
+
|
162 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> prepare_shading_normal_bwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, torch::Tensor grad, bool two_sided_shading, bool opengl)
|
163 |
+
{
|
164 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
165 |
+
|
166 |
+
// Extract input parameters.
|
167 |
+
PrepareShadingNormalKernelParams p;
|
168 |
+
p.two_sided_shading = two_sided_shading;
|
169 |
+
p.opengl = opengl;
|
170 |
+
update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm);
|
171 |
+
|
172 |
+
// Choose launch parameters.
|
173 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
174 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
175 |
+
|
176 |
+
// Setup tensors
|
177 |
+
torch::Tensor pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad;
|
178 |
+
p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad);
|
179 |
+
p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad);
|
180 |
+
p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize, &perturbed_nrm_grad);
|
181 |
+
p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize, &smooth_nrm_grad);
|
182 |
+
p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize, &smooth_tng_grad);
|
183 |
+
p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize, &geom_nrm_grad);
|
184 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
185 |
+
|
186 |
+
// Launch CUDA kernel.
|
187 |
+
void* args[] = { &p };
|
188 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalBwdKernel, gridSize, blockSize, args, 0, stream));
|
189 |
+
|
190 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad);
|
191 |
+
}
|
192 |
+
|
193 |
+
//------------------------------------------------------------------------
|
194 |
+
// lambert
|
195 |
+
|
196 |
+
torch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16)
|
197 |
+
{
|
198 |
+
CHECK_TENSOR(nrm, 4, 3);
|
199 |
+
CHECK_TENSOR(wi, 4, 3);
|
200 |
+
|
201 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
202 |
+
|
203 |
+
// Extract input parameters.
|
204 |
+
LambertKernelParams p;
|
205 |
+
p.out.fp16 = fp16;
|
206 |
+
update_grid(p.gridSize, nrm, wi);
|
207 |
+
|
208 |
+
// Allocate output tensors.
|
209 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
210 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
|
211 |
+
|
212 |
+
// Choose launch parameters.
|
213 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
214 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
215 |
+
|
216 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize);
|
217 |
+
p.wi = make_cuda_tensor(wi, p.gridSize);
|
218 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
219 |
+
|
220 |
+
// Launch CUDA kernel.
|
221 |
+
void* args[] = { &p };
|
222 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertFwdKernel, gridSize, blockSize, args, 0, stream));
|
223 |
+
|
224 |
+
return out;
|
225 |
+
}
|
226 |
+
|
227 |
+
std::tuple<torch::Tensor, torch::Tensor> lambert_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor grad)
|
228 |
+
{
|
229 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
230 |
+
|
231 |
+
// Extract input parameters.
|
232 |
+
LambertKernelParams p;
|
233 |
+
update_grid(p.gridSize, nrm, wi);
|
234 |
+
|
235 |
+
// Choose launch parameters.
|
236 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
237 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
238 |
+
|
239 |
+
torch::Tensor nrm_grad, wi_grad;
|
240 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
|
241 |
+
p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);
|
242 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
243 |
+
|
244 |
+
// Launch CUDA kernel.
|
245 |
+
void* args[] = { &p };
|
246 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertBwdKernel, gridSize, blockSize, args, 0, stream));
|
247 |
+
|
248 |
+
return std::tuple<torch::Tensor, torch::Tensor>(nrm_grad, wi_grad);
|
249 |
+
}
|
250 |
+
|
251 |
+
//------------------------------------------------------------------------
|
252 |
+
// fresnel_shlick
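// Presumably the standard Schlick approximation to the Fresnel term,
//   F(cosTheta) = f0 + (f90 - f0) * (1 - cosTheta)^5,
// evaluated per pixel; f0 and f90 are 3-channel, cosTheta is 1-channel.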
|
253 |
+
|
254 |
+
torch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, bool fp16)
|
255 |
+
{
|
256 |
+
CHECK_TENSOR(f0, 4, 3);
|
257 |
+
CHECK_TENSOR(f90, 4, 3);
|
258 |
+
CHECK_TENSOR(cosTheta, 4, 1);
|
259 |
+
|
260 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
261 |
+
|
262 |
+
// Extract input parameters.
|
263 |
+
FresnelShlickKernelParams p;
|
264 |
+
p.out.fp16 = fp16;
|
265 |
+
update_grid(p.gridSize, f0, f90, cosTheta);
|
266 |
+
|
267 |
+
// Allocate output tensors.
|
268 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
269 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
|
270 |
+
|
271 |
+
// Choose launch parameters.
|
272 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
273 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
274 |
+
|
275 |
+
p.f0 = make_cuda_tensor(f0, p.gridSize);
|
276 |
+
p.f90 = make_cuda_tensor(f90, p.gridSize);
|
277 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);
|
278 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
279 |
+
|
280 |
+
// Launch CUDA kernel.
|
281 |
+
void* args[] = { &p };
|
282 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickFwdKernel, gridSize, blockSize, args, 0, stream));
|
283 |
+
|
284 |
+
return out;
|
285 |
+
}
|
286 |
+
|
287 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> fresnel_shlick_bwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, torch::Tensor grad)
|
288 |
+
{
|
289 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
290 |
+
|
291 |
+
// Extract input parameters.
|
292 |
+
FresnelShlickKernelParams p;
|
293 |
+
update_grid(p.gridSize, f0, f90, cosTheta);
|
294 |
+
|
295 |
+
// Choose launch parameters.
|
296 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
297 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
298 |
+
|
299 |
+
torch::Tensor f0_grad, f90_grad, cosT_grad;
|
300 |
+
p.f0 = make_cuda_tensor(f0, p.gridSize, &f0_grad);
|
301 |
+
p.f90 = make_cuda_tensor(f90, p.gridSize, &f90_grad);
|
302 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosT_grad);
|
303 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
304 |
+
|
305 |
+
// Launch CUDA kernel.
|
306 |
+
void* args[] = { &p };
|
307 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickBwdKernel, gridSize, blockSize, args, 0, stream));
|
308 |
+
|
309 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>(f0_grad, f90_grad, cosT_grad);
|
310 |
+
}
|
311 |
+
|
312 |
+
//------------------------------------------------------------------------
|
313 |
+
// ndf_ggx
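// Presumably the GGX / Trowbridge-Reitz normal distribution function,
//   D(cosTheta) = alphaSqr / (pi * (cosTheta^2 * (alphaSqr - 1) + 1)^2),
// with both alphaSqr and cosTheta passed as 1-channel tensors.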
|
314 |
+
|
315 |
+
torch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16)
|
316 |
+
{
|
317 |
+
CHECK_TENSOR(alphaSqr, 4, 1);
|
318 |
+
CHECK_TENSOR(cosTheta, 4, 1);
|
319 |
+
|
320 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
321 |
+
|
322 |
+
// Extract input parameters.
|
323 |
+
NdfGGXParams p;
|
324 |
+
p.out.fp16 = fp16;
|
325 |
+
update_grid(p.gridSize, alphaSqr, cosTheta);
|
326 |
+
|
327 |
+
// Allocate output tensors.
|
328 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
329 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
|
330 |
+
|
331 |
+
// Choose launch parameters.
|
332 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
333 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
334 |
+
|
335 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);
|
336 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);
|
337 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
338 |
+
|
339 |
+
// Launch CUDA kernel.
|
340 |
+
void* args[] = { &p };
|
341 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXFwdKernel, gridSize, blockSize, args, 0, stream));
|
342 |
+
|
343 |
+
return out;
|
344 |
+
}
|
345 |
+
|
346 |
+
std::tuple<torch::Tensor, torch::Tensor> ndf_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad)
|
347 |
+
{
|
348 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
349 |
+
|
350 |
+
// Extract input parameters.
|
351 |
+
NdfGGXParams p;
|
352 |
+
update_grid(p.gridSize, alphaSqr, cosTheta);
|
353 |
+
|
354 |
+
// Choose launch parameters.
|
355 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
356 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
357 |
+
|
358 |
+
torch::Tensor alphaSqr_grad, cosTheta_grad;
|
359 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);
|
360 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad);
|
361 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
362 |
+
|
363 |
+
// Launch CUDA kernel.
|
364 |
+
void* args[] = { &p };
|
365 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXBwdKernel, gridSize, blockSize, args, 0, stream));
|
366 |
+
|
367 |
+
return std::tuple<torch::Tensor, torch::Tensor>(alphaSqr_grad, cosTheta_grad);
|
368 |
+
}
|
369 |
+
|
370 |
+
//------------------------------------------------------------------------
|
371 |
+
// lambda_ggx
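// Presumably the Smith lambda auxiliary function for GGX,
//   Lambda(theta) = (sqrt(1 + alphaSqr * tan^2(theta)) - 1) / 2,
// used by the masking/shadowing term further below.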
|
372 |
+
|
373 |
+
torch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16)
|
374 |
+
{
|
375 |
+
CHECK_TENSOR(alphaSqr, 4, 1);
|
376 |
+
CHECK_TENSOR(cosTheta, 4, 1);
|
377 |
+
|
378 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
379 |
+
|
380 |
+
// Extract input parameters.
|
381 |
+
NdfGGXParams p;
|
382 |
+
p.out.fp16 = fp16;
|
383 |
+
update_grid(p.gridSize, alphaSqr, cosTheta);
|
384 |
+
|
385 |
+
// Allocate output tensors.
|
386 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
387 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
|
388 |
+
|
389 |
+
// Choose launch parameters.
|
390 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
391 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
392 |
+
|
393 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);
|
394 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);
|
395 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
396 |
+
|
397 |
+
// Launch CUDA kernel.
|
398 |
+
void* args[] = { &p };
|
399 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXFwdKernel, gridSize, blockSize, args, 0, stream));
|
400 |
+
|
401 |
+
return out;
|
402 |
+
}
|
403 |
+
|
404 |
+
std::tuple<torch::Tensor, torch::Tensor> lambda_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad)
|
405 |
+
{
|
406 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
407 |
+
|
408 |
+
// Extract input parameters.
|
409 |
+
NdfGGXParams p;
|
410 |
+
update_grid(p.gridSize, alphaSqr, cosTheta);
|
411 |
+
|
412 |
+
// Choose launch parameters.
|
413 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
414 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
415 |
+
|
416 |
+
torch::Tensor alphaSqr_grad, cosTheta_grad;
|
417 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);
|
418 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad);
|
419 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
420 |
+
|
421 |
+
// Launch CUDA kernel.
|
422 |
+
void* args[] = { &p };
|
423 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXBwdKernel, gridSize, blockSize, args, 0, stream));
|
424 |
+
|
425 |
+
return std::tuple<torch::Tensor, torch::Tensor>(alphaSqr_grad, cosTheta_grad);
|
426 |
+
}
|
427 |
+
|
428 |
+
//------------------------------------------------------------------------
|
429 |
+
// masking_smith
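// Presumably a Smith masking-shadowing term built from the lambda function
// above, of the form G(thetaI, thetaO) = 1 / (1 + Lambda(thetaI) + Lambda(thetaO)),
// taking alphaSqr and the two cosines as 1-channel tensors.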
|
430 |
+
|
431 |
+
torch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, bool fp16)
|
432 |
+
{
|
433 |
+
CHECK_TENSOR(alphaSqr, 4, 1);
|
434 |
+
CHECK_TENSOR(cosThetaI, 4, 1);
|
435 |
+
CHECK_TENSOR(cosThetaO, 4, 1);
|
436 |
+
|
437 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
438 |
+
|
439 |
+
// Extract input parameters.
|
440 |
+
MaskingSmithParams p;
|
441 |
+
p.out.fp16 = fp16;
|
442 |
+
update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO);
|
443 |
+
|
444 |
+
// Allocate output tensors.
|
445 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
446 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
|
447 |
+
|
448 |
+
// Choose launch parameters.
|
449 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
450 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
451 |
+
|
452 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);
|
453 |
+
p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize);
|
454 |
+
p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize);
|
455 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
456 |
+
|
457 |
+
// Launch CUDA kernel.
|
458 |
+
void* args[] = { &p };
|
459 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithFwdKernel, gridSize, blockSize, args, 0, stream));
|
460 |
+
|
461 |
+
return out;
|
462 |
+
}
|
463 |
+
|
464 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> masking_smith_bwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, torch::Tensor grad)
|
465 |
+
{
|
466 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
467 |
+
|
468 |
+
// Extract input parameters.
|
469 |
+
MaskingSmithParams p;
|
470 |
+
update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO);
|
471 |
+
|
472 |
+
// Choose launch parameters.
|
473 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
474 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
475 |
+
|
476 |
+
torch::Tensor alphaSqr_grad, cosThetaI_grad, cosThetaO_grad;
|
477 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);
|
478 |
+
p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize, &cosThetaI_grad);
|
479 |
+
p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize, &cosThetaO_grad);
|
480 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
481 |
+
|
482 |
+
// Launch CUDA kernel.
|
483 |
+
void* args[] = { &p };
|
484 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithBwdKernel, gridSize, blockSize, args, 0, stream));
|
485 |
+
|
486 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>(alphaSqr_grad, cosThetaI_grad, cosThetaO_grad);
|
487 |
+
}
|
488 |
+
|
489 |
+
//------------------------------------------------------------------------
|
490 |
+
// pbr_specular
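// Presumably a Cook-Torrance style microfacet specular lobe combining the
// terms above, roughly D * F * G / (4 * cos(thetaI) * cos(thetaO)), with the
// roughness alpha clamped from below by min_roughness.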
|
491 |
+
|
492 |
+
torch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, bool fp16)
|
493 |
+
{
|
494 |
+
CHECK_TENSOR(col, 4, 3);
|
495 |
+
CHECK_TENSOR(nrm, 4, 3);
|
496 |
+
CHECK_TENSOR(wo, 4, 3);
|
497 |
+
CHECK_TENSOR(wi, 4, 3);
|
498 |
+
CHECK_TENSOR(alpha, 4, 1);
|
499 |
+
|
500 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
501 |
+
|
502 |
+
// Extract input parameters.
|
503 |
+
PbrSpecular p;
|
504 |
+
p.out.fp16 = fp16;
|
505 |
+
p.min_roughness = min_roughness;
|
506 |
+
update_grid(p.gridSize, col, nrm, wo, wi, alpha);
|
507 |
+
|
508 |
+
// Allocate output tensors.
|
509 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
510 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
|
511 |
+
|
512 |
+
// Choose launch parameters.
|
513 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
514 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
515 |
+
|
516 |
+
p.col = make_cuda_tensor(col, p.gridSize);
|
517 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize);
|
518 |
+
p.wo = make_cuda_tensor(wo, p.gridSize);
|
519 |
+
p.wi = make_cuda_tensor(wi, p.gridSize);
|
520 |
+
p.alpha = make_cuda_tensor(alpha, p.gridSize);
|
521 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
522 |
+
|
523 |
+
// Launch CUDA kernel.
|
524 |
+
void* args[] = { &p };
|
525 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularFwdKernel, gridSize, blockSize, args, 0, stream));
|
526 |
+
|
527 |
+
return out;
|
528 |
+
}
|
529 |
+
|
530 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> pbr_specular_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, torch::Tensor grad)
|
531 |
+
{
|
532 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
533 |
+
|
534 |
+
// Extract input parameters.
|
535 |
+
PbrSpecular p;
|
536 |
+
update_grid(p.gridSize, col, nrm, wo, wi, alpha);
|
537 |
+
p.min_roughness = min_roughness;
|
538 |
+
|
539 |
+
// Choose launch parameters.
|
540 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
541 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
542 |
+
|
543 |
+
torch::Tensor col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad;
|
544 |
+
p.col = make_cuda_tensor(col, p.gridSize, &col_grad);
|
545 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
|
546 |
+
p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad);
|
547 |
+
p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);
|
548 |
+
p.alpha = make_cuda_tensor(alpha, p.gridSize, &alpha_grad);
|
549 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
550 |
+
|
551 |
+
// Launch CUDA kernel.
|
552 |
+
void* args[] = { &p };
|
553 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularBwdKernel, gridSize, blockSize, args, 0, stream));
|
554 |
+
|
555 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad);
|
556 |
+
}
|
557 |
+
|
558 |
+
//------------------------------------------------------------------------
|
559 |
+
// pbr_bsdf
|
560 |
+
|
561 |
+
torch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, bool fp16)
|
562 |
+
{
|
563 |
+
CHECK_TENSOR(kd, 4, 3);
|
564 |
+
CHECK_TENSOR(arm, 4, 3);
|
565 |
+
CHECK_TENSOR(pos, 4, 3);
|
566 |
+
CHECK_TENSOR(nrm, 4, 3);
|
567 |
+
CHECK_TENSOR(view_pos, 4, 3);
|
568 |
+
CHECK_TENSOR(light_pos, 4, 3);
|
569 |
+
|
570 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
571 |
+
|
572 |
+
// Extract input parameters.
|
573 |
+
PbrBSDF p;
|
574 |
+
p.out.fp16 = fp16;
|
575 |
+
p.min_roughness = min_roughness;
|
576 |
+
update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos);
|
577 |
+
|
578 |
+
// Allocate output tensors.
|
579 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
580 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
|
581 |
+
|
582 |
+
// Choose launch parameters.
|
583 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
584 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
585 |
+
|
586 |
+
p.kd = make_cuda_tensor(kd, p.gridSize);
|
587 |
+
p.arm = make_cuda_tensor(arm, p.gridSize);
|
588 |
+
p.pos = make_cuda_tensor(pos, p.gridSize);
|
589 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize);
|
590 |
+
p.view_pos = make_cuda_tensor(view_pos, p.gridSize);
|
591 |
+
p.light_pos = make_cuda_tensor(light_pos, p.gridSize);
|
592 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
593 |
+
|
594 |
+
// Launch CUDA kernel.
|
595 |
+
void* args[] = { &p };
|
596 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFFwdKernel, gridSize, blockSize, args, 0, stream));
|
597 |
+
|
598 |
+
return out;
|
599 |
+
}
|
600 |
+
|
601 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> pbr_bsdf_bwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, torch::Tensor grad)
|
602 |
+
{
|
603 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
604 |
+
|
605 |
+
// Extract input parameters.
|
606 |
+
PbrBSDF p;
|
607 |
+
update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos);
|
608 |
+
p.min_roughness = min_roughness;
|
609 |
+
|
610 |
+
// Choose launch parameters.
|
611 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
612 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
613 |
+
|
614 |
+
torch::Tensor kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad;
|
615 |
+
p.kd = make_cuda_tensor(kd, p.gridSize, &kd_grad);
|
616 |
+
p.arm = make_cuda_tensor(arm, p.gridSize, &arm_grad);
|
617 |
+
p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad);
|
618 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
|
619 |
+
p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad);
|
620 |
+
p.light_pos = make_cuda_tensor(light_pos, p.gridSize, &light_pos_grad);
|
621 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
622 |
+
|
623 |
+
// Launch CUDA kernel.
|
624 |
+
void* args[] = { &p };
|
625 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFBwdKernel, gridSize, blockSize, args, 0, stream));
|
626 |
+
|
627 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad);
|
628 |
+
}
|
629 |
+
|
630 |
+
//------------------------------------------------------------------------
|
631 |
+
// loss function
|
632 |
+
|
633 |
+
LossType strToLoss(std::string str)
|
634 |
+
{
|
635 |
+
if (str == "mse")
|
636 |
+
return LOSS_MSE;
|
637 |
+
else if (str == "relmse")
|
638 |
+
return LOSS_RELMSE;
|
639 |
+
else if (str == "smape")
|
640 |
+
return LOSS_SMAPE;
|
641 |
+
else
|
642 |
+
return LOSS_L1;
|
643 |
+
}
|
644 |
+
|
645 |
+
torch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, std::string loss, std::string tonemapper, bool fp16)
|
646 |
+
{
|
647 |
+
CHECK_TENSOR(img, 4, 3);
|
648 |
+
CHECK_TENSOR(target, 4, 3);
|
649 |
+
|
650 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
651 |
+
|
652 |
+
// Extract input parameters.
|
653 |
+
LossKernelParams p;
|
654 |
+
p.out.fp16 = fp16;
|
655 |
+
p.loss = strToLoss(loss);
|
656 |
+
p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE;
|
657 |
+
update_grid(p.gridSize, img, target);
|
658 |
+
|
659 |
+
// Choose launch parameters.
|
660 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
661 |
+
dim3 warpSize = getWarpSize(blockSize);
|
662 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
663 |
+
|
664 |
+
// Allocate output tensors.
|
665 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
666 |
+
torch::Tensor out = torch::empty({ (p.gridSize.z - 1) / warpSize.z + 1, (p.gridSize.y - 1) / warpSize.y + 1, (p.gridSize.x - 1) / warpSize.x + 1, 1 }, opts);
|
667 |
+
|
668 |
+
p.img = make_cuda_tensor(img, p.gridSize);
|
669 |
+
p.target = make_cuda_tensor(target, p.gridSize);
|
670 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
671 |
+
|
672 |
+
// Launch CUDA kernel.
|
673 |
+
void* args[] = { &p };
|
674 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossFwdKernel, gridSize, blockSize, args, 0, stream));
|
675 |
+
|
676 |
+
return out;
|
677 |
+
}
|
678 |
+
|
679 |
+
std::tuple<torch::Tensor, torch::Tensor> image_loss_bwd(torch::Tensor img, torch::Tensor target, torch::Tensor grad, std::string loss, std::string tonemapper)
|
680 |
+
{
|
681 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
682 |
+
|
683 |
+
// Extract input parameters.
|
684 |
+
LossKernelParams p;
|
685 |
+
p.loss = strToLoss(loss);
|
686 |
+
p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE;
|
687 |
+
update_grid(p.gridSize, img, target);
|
688 |
+
|
689 |
+
// Choose launch parameters.
|
690 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
691 |
+
dim3 warpSize = getWarpSize(blockSize);
|
692 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
693 |
+
|
694 |
+
torch::Tensor img_grad, target_grad;
|
695 |
+
p.img = make_cuda_tensor(img, p.gridSize, &img_grad);
|
696 |
+
p.target = make_cuda_tensor(target, p.gridSize, &target_grad);
|
697 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
698 |
+
|
699 |
+
// Launch CUDA kernel.
|
700 |
+
void* args[] = { &p };
|
701 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossBwdKernel, gridSize, blockSize, args, 0, stream));
|
702 |
+
|
703 |
+
return std::tuple<torch::Tensor, torch::Tensor>(img_grad, target_grad);
|
704 |
+
}
|
705 |
+
|
706 |
+
//------------------------------------------------------------------------
|
707 |
+
// transform function
|
708 |
+
|
709 |
+
torch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool isPoints, bool fp16)
|
710 |
+
{
|
711 |
+
CHECK_TENSOR(points, 3, 3);
|
712 |
+
CHECK_TENSOR(matrix, 3, 4);
|
713 |
+
|
714 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
715 |
+
|
716 |
+
// Extract input parameters.
|
717 |
+
XfmKernelParams p;
|
718 |
+
p.out.fp16 = fp16;
|
719 |
+
p.isPoints = isPoints;
|
720 |
+
p.gridSize.x = points.size(1);
|
721 |
+
p.gridSize.y = 1;
|
722 |
+
p.gridSize.z = std::max(matrix.size(0), points.size(0));
|
723 |
+
|
724 |
+
// Choose launch parameters.
|
725 |
+
dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1);
|
726 |
+
dim3 warpSize = getWarpSize(blockSize);
|
727 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
728 |
+
|
729 |
+
// Allocate output tensors.
|
730 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
731 |
+
torch::Tensor out = isPoints ? torch::empty({ matrix.size(0), points.size(1), 4 }, opts) : torch::empty({ matrix.size(0), points.size(1), 3 }, opts);
|
732 |
+
|
733 |
+
p.points = make_cuda_tensor(points, p.gridSize);
|
734 |
+
p.matrix = make_cuda_tensor(matrix, p.gridSize);
|
735 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
736 |
+
|
737 |
+
// Launch CUDA kernel.
|
738 |
+
void* args[] = { &p };
|
739 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsFwdKernel, gridSize, blockSize, args, 0, stream));
|
740 |
+
|
741 |
+
return out;
|
742 |
+
}
|
743 |
+
|
744 |
+
torch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch::Tensor grad, bool isPoints)
|
745 |
+
{
|
746 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
747 |
+
|
748 |
+
// Extract input parameters.
|
749 |
+
XfmKernelParams p;
|
750 |
+
p.isPoints = isPoints;
|
751 |
+
p.gridSize.x = points.size(1);
|
752 |
+
p.gridSize.y = 1;
|
753 |
+
p.gridSize.z = std::max(matrix.size(0), points.size(0));
|
754 |
+
|
755 |
+
// Choose launch parameters.
|
756 |
+
dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1);
|
757 |
+
dim3 warpSize = getWarpSize(blockSize);
|
758 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
759 |
+
|
760 |
+
torch::Tensor points_grad;
|
761 |
+
p.points = make_cuda_tensor(points, p.gridSize, &points_grad);
|
762 |
+
p.matrix = make_cuda_tensor(matrix, p.gridSize);
|
763 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
764 |
+
|
765 |
+
// Launch CUDA kernel.
|
766 |
+
void* args[] = { &p };
|
767 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsBwdKernel, gridSize, blockSize, args, 0, stream));
|
768 |
+
|
769 |
+
return points_grad;
|
770 |
+
}
|
771 |
+
|
772 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
773 |
+
m.def("prepare_shading_normal_fwd", &prepare_shading_normal_fwd, "prepare_shading_normal_fwd");
|
774 |
+
m.def("prepare_shading_normal_bwd", &prepare_shading_normal_bwd, "prepare_shading_normal_bwd");
|
775 |
+
m.def("lambert_fwd", &lambert_fwd, "lambert_fwd");
|
776 |
+
m.def("lambert_bwd", &lambert_bwd, "lambert_bwd");
|
777 |
+
m.def("fresnel_shlick_fwd", &fresnel_shlick_fwd, "fresnel_shlick_fwd");
|
778 |
+
m.def("fresnel_shlick_bwd", &fresnel_shlick_bwd, "fresnel_shlick_bwd");
|
779 |
+
m.def("ndf_ggx_fwd", &ndf_ggx_fwd, "ndf_ggx_fwd");
|
780 |
+
m.def("ndf_ggx_bwd", &ndf_ggx_bwd, "ndf_ggx_bwd");
|
781 |
+
m.def("lambda_ggx_fwd", &lambda_ggx_fwd, "lambda_ggx_fwd");
|
782 |
+
m.def("lambda_ggx_bwd", &lambda_ggx_bwd, "lambda_ggx_bwd");
|
783 |
+
m.def("masking_smith_fwd", &masking_smith_fwd, "masking_smith_fwd");
|
784 |
+
m.def("masking_smith_bwd", &masking_smith_bwd, "masking_smith_bwd");
|
785 |
+
m.def("pbr_specular_fwd", &pbr_specular_fwd, "pbr_specular_fwd");
|
786 |
+
m.def("pbr_specular_bwd", &pbr_specular_bwd, "pbr_specular_bwd");
|
787 |
+
m.def("pbr_bsdf_fwd", &pbr_bsdf_fwd, "pbr_bsdf_fwd");
|
788 |
+
m.def("pbr_bsdf_bwd", &pbr_bsdf_bwd, "pbr_bsdf_bwd");
|
789 |
+
m.def("image_loss_fwd", &image_loss_fwd, "image_loss_fwd");
|
790 |
+
m.def("image_loss_bwd", &image_loss_bwd, "image_loss_bwd");
|
791 |
+
m.def("xfm_fwd", &xfm_fwd, "xfm_fwd");
|
792 |
+
m.def("xfm_bwd", &xfm_bwd, "xfm_bwd");
|
793 |
+
}
|
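The bindings above only become reachable from Python once the extension has been JIT-compiled. A minimal sketch of that flow, assuming the c_src files listed in this upload, a CUDA-capable PyTorch install, and illustrative tensor shapes (the source directory, shapes, and contiguity requirements are assumptions here; the ops.py wrapper further below is the intended entry point):

    import os
    import torch
    import torch.utils.cpp_extension

    # Hypothetical location of the sources from this upload.
    src_dir = "nvdiffmodeling/src/renderutils/c_src"
    sources = [os.path.join(src_dir, f) for f in
               ["mesh.cu", "loss.cu", "bsdf.cu", "normal.cu", "common.cpp", "torch_bindings.cpp"]]

    # JIT-compile and load the pybind module.
    plugin = torch.utils.cpp_extension.load(
        name="renderutils_plugin", sources=sources,
        extra_ldflags=["-lcuda"], with_cuda=True)

    # lambert_fwd takes contiguous [minibatch, height, width, 3] CUDA tensors
    # and returns a [minibatch, height, width, 1] diffuse term.
    nrm = torch.nn.functional.normalize(torch.randn(1, 4, 4, 3, device="cuda"), dim=-1).contiguous()
    wi  = torch.nn.functional.normalize(torch.randn(1, 4, 4, 3, device="cuda"), dim=-1).contiguous()
    out = plugin.lambert_fwd(nrm, wi, False)   # fp16=False -> float32 output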
nvdiffmodeling/src/renderutils/c_src/vec3f.h
ADDED
@@ -0,0 +1,106 @@

1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#pragma once
|
10 |
+
|
11 |
+
struct vec3f
|
12 |
+
{
|
13 |
+
float x, y, z;
|
14 |
+
|
15 |
+
#ifdef __CUDACC__
|
16 |
+
__device__ vec3f() { }
|
17 |
+
__device__ vec3f(float v) { x = v; y = v; z = v; }
|
18 |
+
__device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; }
|
19 |
+
__device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; }
|
20 |
+
|
21 |
+
__device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; }
|
22 |
+
__device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; }
|
23 |
+
__device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; }
|
24 |
+
__device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; }
|
25 |
+
#endif
|
26 |
+
};
|
27 |
+
|
28 |
+
#ifdef __CUDACC__
|
29 |
+
__device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); }
|
30 |
+
__device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); }
|
31 |
+
__device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); }
|
32 |
+
__device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); }
|
33 |
+
__device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); }
|
34 |
+
|
35 |
+
__device__ static inline float sum(vec3f a)
|
36 |
+
{
|
37 |
+
return a.x + a.y + a.z;
|
38 |
+
}
|
39 |
+
|
40 |
+
__device__ static inline vec3f cross(vec3f a, vec3f b)
|
41 |
+
{
|
42 |
+
vec3f out;
|
43 |
+
out.x = a.y * b.z - a.z * b.y;
|
44 |
+
out.y = a.z * b.x - a.x * b.z;
|
45 |
+
out.z = a.x * b.y - a.y * b.x;
|
46 |
+
return out;
|
47 |
+
}
|
48 |
+
|
49 |
+
__device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out)
|
50 |
+
{
|
51 |
+
d_a.x += d_out.z * b.y - d_out.y * b.z;
|
52 |
+
d_a.y += d_out.x * b.z - d_out.z * b.x;
|
53 |
+
d_a.z += d_out.y * b.x - d_out.x * b.y;
|
54 |
+
|
55 |
+
d_b.x += d_out.y * a.z - d_out.z * a.y;
|
56 |
+
d_b.y += d_out.z * a.x - d_out.x * a.z;
|
57 |
+
d_b.z += d_out.x * a.y - d_out.y * a.x;
|
58 |
+
}
|
59 |
+
|
60 |
+
__device__ static inline float dot(vec3f a, vec3f b)
|
61 |
+
{
|
62 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
63 |
+
}
|
64 |
+
|
65 |
+
__device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out)
|
66 |
+
{
|
67 |
+
d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z;
|
68 |
+
d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z;
|
69 |
+
}
|
70 |
+
|
71 |
+
__device__ static inline vec3f reflect(vec3f x, vec3f n)
|
72 |
+
{
|
73 |
+
return n * 2.0f * dot(n, x) - x;
|
74 |
+
}
|
75 |
+
|
76 |
+
__device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out)
|
77 |
+
{
|
78 |
+
d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z);
|
79 |
+
d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z);
|
80 |
+
d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1);
|
81 |
+
|
82 |
+
d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x);
|
83 |
+
d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y);
|
84 |
+
d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z));
|
85 |
+
}
|
86 |
+
|
87 |
+
__device__ static inline vec3f safeNormalize(vec3f v)
|
88 |
+
{
|
89 |
+
float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
|
90 |
+
return l > 0.0f ? (v / l) : vec3f(0.0f);
|
91 |
+
}
|
92 |
+
|
93 |
+
__device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out)
|
94 |
+
{
|
95 |
+
|
96 |
+
float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
|
97 |
+
if (l > 0.0f)
|
98 |
+
{
|
99 |
+
float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f);
|
100 |
+
d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac;
|
101 |
+
d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac;
|
102 |
+
d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac;
|
103 |
+
}
|
104 |
+
}
|
105 |
+
|
106 |
+
#endif
|
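For reference, the backward helper bwdSafeNormalize above accumulates the Jacobian of the normalization map applied to the incoming gradient (this restates the code; it is not part of the original file):

    \frac{\partial}{\partial \mathbf{v}}\,\frac{\mathbf{v}}{\lVert\mathbf{v}\rVert}
      \;=\; \frac{\lVert\mathbf{v}\rVert^{2}\,\mathbf{I} - \mathbf{v}\,\mathbf{v}^{\top}}{\lVert\mathbf{v}\rVert^{3}}

so d_v += J * d_out, with the factor fac = 1 / ||v||^3 and the whole update skipped when the vector has zero length, exactly as guarded by the l > 0 test.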
nvdiffmodeling/src/renderutils/c_src/vec4f.h
ADDED
@@ -0,0 +1,22 @@
1 |
+
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#pragma once
|
10 |
+
|
11 |
+
struct vec4f
|
12 |
+
{
|
13 |
+
float x, y, z, w;
|
14 |
+
|
15 |
+
#ifdef __CUDACC__
|
16 |
+
__device__ vec4f() { }
|
17 |
+
__device__ vec4f(float v) { x = v; y = v; z = v; w = v; }
|
18 |
+
__device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; }
|
19 |
+
__device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; }
|
20 |
+
#endif
|
21 |
+
};
|
22 |
+
|
nvdiffmodeling/src/renderutils/loss.py
ADDED
@@ -0,0 +1,40 @@
1 |
+
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
#----------------------------------------------------------------------------
|
12 |
+
# HDR image losses
|
13 |
+
#----------------------------------------------------------------------------
|
14 |
+
|
15 |
+
def _tonemap_srgb(f):
|
16 |
+
return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)
|
17 |
+
|
18 |
+
def _SMAPE(img, target, eps=0.01):
|
19 |
+
nom = torch.abs(img - target)
|
20 |
+
denom = torch.abs(img) + torch.abs(target) + eps
|
21 |
+
return torch.mean(nom / denom)
|
22 |
+
|
23 |
+
def _RELMSE(img, target, eps=0.1):
|
24 |
+
nom = (img - target) * (img - target)
|
25 |
+
denom = img * img + target * target + eps
|
26 |
+
return torch.mean(nom / denom)
|
27 |
+
|
28 |
+
def image_loss_fn(img, target, loss, tonemapper):
|
29 |
+
if tonemapper == 'log_srgb':
|
30 |
+
img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1))
|
31 |
+
target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1))
|
32 |
+
|
33 |
+
if loss == 'mse':
|
34 |
+
return torch.nn.functional.mse_loss(img, target)
|
35 |
+
elif loss == 'smape':
|
36 |
+
return _SMAPE(img, target)
|
37 |
+
elif loss == 'relmse':
|
38 |
+
return _RELMSE(img, target)
|
39 |
+
else:
|
40 |
+
return torch.nn.functional.l1_loss(img, target)
|
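In formulas, the two robust losses above are, with the default values of epsilon (0.01 for SMAPE, 0.1 for relative MSE):

    \mathrm{SMAPE}(I,T) = \operatorname{mean}\!\left(\frac{|I-T|}{|I|+|T|+\varepsilon}\right),
    \qquad
    \mathrm{relMSE}(I,T) = \operatorname{mean}\!\left(\frac{(I-T)^{2}}{I^{2}+T^{2}+\varepsilon}\right)

and the 'log_srgb' tonemapper first maps both images through the sRGB transfer curve applied to log(1 + clamp(x, 0, 65535)) before the chosen loss is evaluated.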
nvdiffmodeling/src/renderutils/ops.py
ADDED
@@ -0,0 +1,425 @@
1 |
+
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import torch
|
13 |
+
import torch.utils.cpp_extension
|
14 |
+
|
15 |
+
from .bsdf import *
|
16 |
+
from .loss import *
|
17 |
+
|
18 |
+
#----------------------------------------------------------------------------
|
19 |
+
# C++/Cuda plugin compiler/loader.
|
20 |
+
|
21 |
+
_plugin = None
|
22 |
+
if _plugin is None:
|
23 |
+
|
24 |
+
# Make sure we can find the necessary compiler and library binaries.
|
25 |
+
if os.name == 'nt':
|
26 |
+
def find_cl_path():
|
27 |
+
import glob
|
28 |
+
for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']:
|
29 |
+
paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True)
|
30 |
+
if paths:
|
31 |
+
return paths[0]
|
32 |
+
|
33 |
+
# If cl.exe is not on path, try to find it.
|
34 |
+
if os.system("where cl.exe >nul 2>nul") != 0:
|
35 |
+
cl_path = find_cl_path()
|
36 |
+
if cl_path is None:
|
37 |
+
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
38 |
+
os.environ['PATH'] += ';' + cl_path
|
39 |
+
|
40 |
+
# Linker options.
|
41 |
+
if os.name == 'posix':
|
42 |
+
ldflags = ['-lcuda']
|
43 |
+
elif os.name == 'nt':
|
44 |
+
ldflags = ['/DEFAULTLIB:cuda']
|
45 |
+
|
46 |
+
# List of sources.
|
47 |
+
source_files = [
|
48 |
+
'c_src/mesh.cu',
|
49 |
+
'c_src/loss.cu',
|
50 |
+
'c_src/bsdf.cu',
|
51 |
+
'c_src/normal.cu',
|
52 |
+
'c_src/common.cpp',
|
53 |
+
'c_src/torch_bindings.cpp'
|
54 |
+
]
|
55 |
+
|
56 |
+
# Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.
|
57 |
+
os.environ['TORCH_CUDA_ARCH_LIST'] = ''
|
58 |
+
|
59 |
+
# Compile and load.
|
60 |
+
source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]
|
61 |
+
torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_ldflags=ldflags, with_cuda=True, verbose=True)
|
62 |
+
|
63 |
+
# Import, cache, and return the compiled module.
|
64 |
+
import renderutils_plugin
|
65 |
+
_plugin = renderutils_plugin
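# _plugin now exposes the raw fwd/bwd kernels bound in c_src/torch_bindings.cpp
# (lambert_fwd, ndf_ggx_fwd, image_loss_fwd, ...); the autograd.Function wrappers
# below pair each forward call with its matching backward call so gradients flow.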
|
66 |
+
|
67 |
+
#----------------------------------------------------------------------------
|
68 |
+
# Internal kernels, just used for testing functionality
|
69 |
+
|
70 |
+
class _fresnel_shlick_func(torch.autograd.Function):
|
71 |
+
@staticmethod
|
72 |
+
def forward(ctx, f0, f90, cosTheta):
|
73 |
+
out = _plugin.fresnel_shlick_fwd(f0, f90, cosTheta, False)
|
74 |
+
ctx.save_for_backward(f0, f90, cosTheta)
|
75 |
+
return out
|
76 |
+
|
77 |
+
@staticmethod
|
78 |
+
def backward(ctx, dout):
|
79 |
+
f0, f90, cosTheta = ctx.saved_variables
|
80 |
+
return _plugin.fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,)
|
81 |
+
|
82 |
+
def _fresnel_shlick(f0, f90, cosTheta, use_python=False):
|
83 |
+
if use_python:
|
84 |
+
out = bsdf_fresnel_shlick(f0, f90, cosTheta)
|
85 |
+
else:
|
86 |
+
out = _fresnel_shlick_func.apply(f0, f90, cosTheta)
|
87 |
+
|
88 |
+
if torch.is_anomaly_enabled():
|
89 |
+
assert torch.all(torch.isfinite(out)), "Output of _fresnel_shlick contains inf or NaN"
|
90 |
+
return out
|
91 |
+
|
92 |
+
|
93 |
+
class _ndf_ggx_func(torch.autograd.Function):
|
94 |
+
@staticmethod
|
95 |
+
def forward(ctx, alphaSqr, cosTheta):
|
96 |
+
out = _plugin.ndf_ggx_fwd(alphaSqr, cosTheta, False)
|
97 |
+
ctx.save_for_backward(alphaSqr, cosTheta)
|
98 |
+
return out
|
99 |
+
|
100 |
+
@staticmethod
|
101 |
+
def backward(ctx, dout):
|
102 |
+
alphaSqr, cosTheta = ctx.saved_variables
|
103 |
+
return _plugin.ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)
|
104 |
+
|
105 |
+
def _ndf_ggx(alphaSqr, cosTheta, use_python=False):
|
106 |
+
if use_python:
|
107 |
+
out = bsdf_ndf_ggx(alphaSqr, cosTheta)
|
108 |
+
else:
|
109 |
+
out = _ndf_ggx_func.apply(alphaSqr, cosTheta)
|
110 |
+
|
111 |
+
if torch.is_anomaly_enabled():
|
112 |
+
assert torch.all(torch.isfinite(out)), "Output of _ndf_ggx contains inf or NaN"
|
113 |
+
return out
|
114 |
+
|
115 |
+
class _lambda_ggx_func(torch.autograd.Function):
|
116 |
+
@staticmethod
|
117 |
+
def forward(ctx, alphaSqr, cosTheta):
|
118 |
+
out = _plugin.lambda_ggx_fwd(alphaSqr, cosTheta, False)
|
119 |
+
ctx.save_for_backward(alphaSqr, cosTheta)
|
120 |
+
return out
|
121 |
+
|
122 |
+
@staticmethod
|
123 |
+
def backward(ctx, dout):
|
124 |
+
alphaSqr, cosTheta = ctx.saved_variables
|
125 |
+
return _plugin.lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)
|
126 |
+
|
127 |
+
def _lambda_ggx(alphaSqr, cosTheta, use_python=False):
|
128 |
+
if use_python:
|
129 |
+
out = bsdf_lambda_ggx(alphaSqr, cosTheta)
|
130 |
+
else:
|
131 |
+
out = _lambda_ggx_func.apply(alphaSqr, cosTheta)
|
132 |
+
|
133 |
+
if torch.is_anomaly_enabled():
|
134 |
+
assert torch.all(torch.isfinite(out)), "Output of _lambda_ggx contains inf or NaN"
|
135 |
+
return out
|
136 |
+
|
137 |
+
class _masking_smith_func(torch.autograd.Function):
|
138 |
+
@staticmethod
|
139 |
+
def forward(ctx, alphaSqr, cosThetaI, cosThetaO):
|
140 |
+
ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO)
|
141 |
+
out = _plugin.masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False)
|
142 |
+
return out
|
143 |
+
|
144 |
+
@staticmethod
|
145 |
+
def backward(ctx, dout):
|
146 |
+
alphaSqr, cosThetaI, cosThetaO = ctx.saved_variables
|
147 |
+
return _plugin.masking_smith_bwd(alphaSqr, cosThetaI, cosThetaO, dout) + (None,)
|
148 |
+
|
149 |
+
def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False):
|
150 |
+
if use_python:
|
151 |
+
out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO)
|
152 |
+
else:
|
153 |
+
out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO)
|
154 |
+
|
155 |
+
if torch.is_anomaly_enabled():
|
156 |
+
assert torch.all(torch.isfinite(out)), "Output of _masking_smith contains inf or NaN"
|
157 |
+
return out
|
158 |
+
|
159 |
+
#----------------------------------------------------------------------------
|
160 |
+
# Shading normal setup (bump mapping + bent normals)
|
161 |
+
|
162 |
+
class _prepare_shading_normal_func(torch.autograd.Function):
|
163 |
+
@staticmethod
|
164 |
+
def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
|
165 |
+
ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl
|
166 |
+
out = _plugin.prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False)
|
167 |
+
ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm)
|
168 |
+
return out
|
169 |
+
|
170 |
+
@staticmethod
|
171 |
+
def backward(ctx, dout):
|
172 |
+
pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_variables
|
173 |
+
return _plugin.prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None)
|
174 |
+
|
175 |
+
def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False):
|
176 |
+
'''Takes care of all corner cases and produces a final normal used for shading:
|
177 |
+
- Constructs tangent space
|
178 |
+
- Flips normal direction based on geometric normal for two-sided shading
|
179 |
+
- Perturbs shading normal by normal map
|
180 |
+
- Bends backfacing normals towards the camera to avoid shading artifacts
|
181 |
+
|
182 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
pos: World space g-buffer position.
|
186 |
+
view_pos: Camera position in world space (typically using broadcasting).
|
187 |
+
perturbed_nrm: Tangent-space normal perturbation from normal map lookup.
|
188 |
+
smooth_nrm: Interpolated vertex normals.
|
189 |
+
smooth_tng: Interpolated vertex tangents.
|
190 |
+
geom_nrm: Geometric (face) normals.
|
191 |
+
two_sided_shading: Use one/two sided shading
|
192 |
+
opengl: Use OpenGL/DirectX normal map conventions
|
193 |
+
use_python: Use PyTorch implementation (for validation)
|
194 |
+
Returns:
|
195 |
+
Final shading normal
|
196 |
+
'''
|
197 |
+
|
198 |
+
if perturbed_nrm is None:
|
199 |
+
perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...]
|
200 |
+
|
201 |
+
if use_python:
|
202 |
+
out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
|
203 |
+
else:
|
204 |
+
out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
|
205 |
+
|
206 |
+
if torch.is_anomaly_enabled():
|
207 |
+
assert torch.all(torch.isfinite(out)), "Output of prepare_shading_normal contains inf or NaN"
|
208 |
+
return out
|
209 |
+
|
210 |
+
#----------------------------------------------------------------------------
|
211 |
+
# BSDF functions
|
212 |
+
|
213 |
+
class _lambert_func(torch.autograd.Function):
|
214 |
+
@staticmethod
|
215 |
+
def forward(ctx, nrm, wi):
|
216 |
+
out = _plugin.lambert_fwd(nrm, wi, False)
|
217 |
+
ctx.save_for_backward(nrm, wi)
|
218 |
+
return out
|
219 |
+
|
220 |
+
@staticmethod
|
221 |
+
def backward(ctx, dout):
|
222 |
+
nrm, wi = ctx.saved_variables
|
223 |
+
return _plugin.lambert_bwd(nrm, wi, dout) + (None,)
|
224 |
+
|
225 |
+
def lambert(nrm, wi, use_python=False):
|
226 |
+
'''Lambertian bsdf.
|
227 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
nrm: World space shading normal.
|
231 |
+
wi: World space light vector.
|
232 |
+
use_python: Use PyTorch implementation (for validation)
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
Shaded diffuse value with shape [minibatch_size, height, width, 1]
|
236 |
+
'''
|
237 |
+
|
238 |
+
if use_python:
|
239 |
+
out = bsdf_lambert(nrm, wi)
|
240 |
+
else:
|
241 |
+
out = _lambert_func.apply(nrm, wi)
|
242 |
+
|
243 |
+
if torch.is_anomaly_enabled():
|
244 |
+
assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN"
|
245 |
+
return out
|
246 |
+
|
247 |
+
class _pbr_specular_func(torch.autograd.Function):
|
248 |
+
@staticmethod
|
249 |
+
def forward(ctx, col, nrm, wo, wi, alpha, min_roughness):
|
250 |
+
ctx.save_for_backward(col, nrm, wo, wi, alpha)
|
251 |
+
ctx.min_roughness = min_roughness
|
252 |
+
out = _plugin.pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False)
|
253 |
+
return out
|
254 |
+
|
255 |
+
@staticmethod
|
256 |
+
def backward(ctx, dout):
|
257 |
+
col, nrm, wo, wi, alpha = ctx.saved_variables
|
258 |
+
return _plugin.pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None)
|
259 |
+
|
260 |
+
def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False):
|
261 |
+
'''Physically-based specular bsdf.
|
262 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
col: Specular lobe color
|
266 |
+
nrm: World space shading normal.
|
267 |
+
wo: World space camera vector.
|
268 |
+
wi: World space light vector
|
269 |
+
alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1]
|
270 |
+
min_roughness: Scalar roughness clamping threshold
|
271 |
+
|
272 |
+
use_python: Use PyTorch implementation (for validation)
|
273 |
+
Returns:
|
274 |
+
Shaded specular color
|
275 |
+
'''
|
276 |
+
|
277 |
+
if use_python:
|
278 |
+
out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=min_roughness)
|
279 |
+
else:
|
280 |
+
out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness)
|
281 |
+
|
282 |
+
if torch.is_anomaly_enabled():
|
283 |
+
assert torch.all(torch.isfinite(out)), "Output of pbr_specular contains inf or NaN"
|
284 |
+
return out
|
285 |
+
|
286 |
+
class _pbr_bsdf_func(torch.autograd.Function):
|
287 |
+
@staticmethod
|
288 |
+
def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness):
|
289 |
+
ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos)
|
290 |
+
ctx.min_roughness = min_roughness
|
291 |
+
out = _plugin.pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, False)
|
292 |
+
return out
|
293 |
+
|
294 |
+
@staticmethod
|
295 |
+
def backward(ctx, dout):
|
296 |
+
kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_variables
|
297 |
+
return _plugin.pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, dout) + (None, None)
|
298 |
+
|
299 |
+
def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, use_python=False):
|
300 |
+
'''Physically-based bsdf, both diffuse & specular lobes
|
301 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
kd: Diffuse albedo.
|
305 |
+
arm: Specular parameters (attenuation, linear roughness, metalness).
|
306 |
+
pos: World space position.
|
307 |
+
nrm: World space shading normal.
|
308 |
+
view_pos: Camera position in world space, typically using broadcasting.
|
309 |
+
light_pos: Light position in world space, typically using broadcasting.
|
310 |
+
min_roughness: Scalar roughness clamping threshold
|
311 |
+
|
312 |
+
use_python: Use PyTorch implementation (for validation)
|
313 |
+
|
314 |
+
Returns:
|
315 |
+
Shaded color.
|
316 |
+
'''
|
317 |
+
|
318 |
+
if use_python:
|
319 |
+
out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=min_roughness)
|
320 |
+
else:
|
321 |
+
out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness)
|
322 |
+
|
323 |
+
if torch.is_anomaly_enabled():
|
324 |
+
assert torch.all(torch.isfinite(out)), "Output of pbr_bsdf contains inf or NaN"
|
325 |
+
return out
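# Illustrative usage sketch (hypothetical shapes and values; broadcasting keeps
# the camera and light positions at [1, 1, 1, 3]):
#
#   kd        = torch.rand(1, 256, 256, 3, device='cuda')
#   arm       = torch.rand(1, 256, 256, 3, device='cuda')
#   pos       = torch.rand(1, 256, 256, 3, device='cuda')
#   nrm       = torch.nn.functional.normalize(torch.randn(1, 256, 256, 3, device='cuda'), dim=-1)
#   view_pos  = torch.tensor([0.0, 0.0, 3.0], device='cuda')[None, None, None, :]
#   light_pos = torch.tensor([2.0, 2.0, 2.0], device='cuda')[None, None, None, :]
#   rgb       = pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos)   # [1, 256, 256, 3]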
|
326 |
+
|
327 |
+
#----------------------------------------------------------------------------
|
328 |
+
# Fast image loss function
|
329 |
+
|
330 |
+
class _image_loss_func(torch.autograd.Function):
|
331 |
+
@staticmethod
|
332 |
+
def forward(ctx, img, target, loss, tonemapper):
|
333 |
+
ctx.loss, ctx.tonemapper = loss, tonemapper
|
334 |
+
ctx.save_for_backward(img, target)
|
335 |
+
out = _plugin.image_loss_fwd(img, target, loss, tonemapper, False)
|
336 |
+
return out
|
337 |
+
|
338 |
+
@staticmethod
|
339 |
+
def backward(ctx, dout):
|
340 |
+
img, target = ctx.saved_variables
|
341 |
+
return _plugin.image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + (None, None, None)
|
342 |
+
|
343 |
+
def image_loss(img, target, loss='l1', tonemapper='none', use_python=False):
|
344 |
+
'''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf.
|
345 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
|
346 |
+
|
347 |
+
Args:
|
348 |
+
img: Input image.
|
349 |
+
target: Target (reference) image.
|
350 |
+
loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse']
|
351 |
+
tonemapper: Tonemapping operations. Valid options are ['none', 'log_srgb']
|
352 |
+
use_python: Use PyTorch implementation (for validation)
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
Image space loss (scalar value).
|
356 |
+
'''
|
357 |
+
if use_python:
|
358 |
+
out = image_loss_fn(img, target, loss, tonemapper)
|
359 |
+
else:
|
360 |
+
out = _image_loss_func.apply(img, target, loss, tonemapper)
|
361 |
+
out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2])
|
362 |
+
|
363 |
+
if torch.is_anomaly_enabled():
|
364 |
+
assert torch.all(torch.isfinite(out)), "Output of image_loss contains inf or NaN"
|
365 |
+
return out
|
366 |
+
|
367 |
+
#----------------------------------------------------------------------------
|
368 |
+
# Transform points function
|
369 |
+
|
370 |
+
class _xfm_func(torch.autograd.Function):
|
371 |
+
@staticmethod
|
372 |
+
def forward(ctx, points, matrix, isPoints):
|
373 |
+
ctx.save_for_backward(points, matrix)
|
374 |
+
ctx.isPoints = isPoints
|
375 |
+
out = _plugin.xfm_fwd(points, matrix, isPoints, False)
|
376 |
+
return out
|
377 |
+
|
378 |
+
@staticmethod
|
379 |
+
def backward(ctx, dout):
|
380 |
+
points, matrix = ctx.saved_variables
|
381 |
+
return (_plugin.xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None)
|
382 |
+
|
383 |
+
def xfm_points(points, matrix, use_python=False):
|
384 |
+
'''Transform points.
|
385 |
+
Note: this method does not back-propagate matrix gradients by default for performance reasons. For matrix gradients,
|
386 |
+
enable use_python=True or use torch.matmul instead.
|
387 |
+
|
388 |
+
Args:
|
389 |
+
points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
|
390 |
+
matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
|
391 |
+
use_python: Use PyTorch's torch.matmul (for validation)
|
392 |
+
Returns:
|
393 |
+
Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
|
394 |
+
'''
|
395 |
+
if use_python:
|
396 |
+
out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
|
397 |
+
else:
|
398 |
+
out = _xfm_func.apply(points, matrix, True)
|
399 |
+
|
400 |
+
if torch.is_anomaly_enabled():
|
401 |
+
assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
|
402 |
+
return out
|
403 |
+
|
404 |
+
def xfm_vectors(vectors, matrix, use_python=False):
|
405 |
+
'''Transform vectors.
|
406 |
+
Note: this method does not back-propagate matrix gradients by default for performance reasons. For matrix gradients,
|
407 |
+
enable use_python=True or use torch.matmul instead.
|
408 |
+
|
409 |
+
Args:
|
410 |
+
vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
|
411 |
+
matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
|
412 |
+
use_python: Use PyTorch's torch.matmul (for validation)
|
413 |
+
|
414 |
+
Returns:
|
415 |
+
Transformed vectors in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
|
416 |
+
'''
|
417 |
+
|
418 |
+
if use_python:
|
419 |
+
out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous()
|
420 |
+
else:
|
421 |
+
out = _xfm_func.apply(vectors, matrix, False)
|
422 |
+
|
423 |
+
if torch.is_anomaly_enabled():
|
424 |
+
assert torch.all(torch.isfinite(out)), "Output of xfm_vectors contains inf or NaN"
|
425 |
+
return out
|
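
For orientation, here is a minimal usage sketch of the ops above (pbr_bsdf and image_loss). It is not part of the uploaded sources: it assumes renderutils imports the way the tests below import it, that a CUDA device is available, and that tensor shapes follow the docstrings; use_python=True selects the PyTorch reference path so no compiled plugin is required.

# Hypothetical usage sketch (not in the upstream files); all values are illustrative only.
import torch
import renderutils as ru  # assumed import path, as used by the tests below

batch, h, w = 1, 4, 4
kd    = torch.rand(batch, h, w, 3, device='cuda', requires_grad=True)  # diffuse albedo
arm   = torch.rand(batch, h, w, 3, device='cuda')                      # (attenuation, roughness, metalness)
pos   = torch.rand(batch, h, w, 3, device='cuda')                      # world-space position
nrm   = torch.rand(batch, h, w, 3, device='cuda')                      # world-space shading normal
view  = torch.rand(batch, 1, 1, 3, device='cuda')                      # broadcast camera position
light = torch.rand(batch, 1, 1, 3, device='cuda')                      # broadcast light position

# use_python=True runs the PyTorch reference implementation (validation path, full gradients).
color = ru.pbr_bsdf(kd, arm, pos, nrm, view, light, use_python=True)
loss  = ru.image_loss(color, torch.rand_like(color), loss='l1', tonemapper='log_srgb', use_python=True)
loss.backward()  # gradients flow back to kd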
nvdiffmodeling/src/renderutils/tests/test_bsdf.py
ADDED
@@ -0,0 +1,266 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch

import os
import sys
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import renderutils as ru

RES = 4
DTYPE = torch.float32

def relative_loss(name, ref, cuda):
    ref = ref.float()
    cuda = cuda.float()
    print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())

def test_normal():
    pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    pos_ref = pos_cuda.clone().detach().requires_grad_(True)
    view_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    view_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True)
    perturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    perturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True)
    smooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    smooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True)
    smooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    smooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True)
    geom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    geom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')

    ref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    cuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print("    bent normal")
    print("-------------------------------------------------------------")
    relative_loss("res:", ref, cuda)
    relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
    relative_loss("view_pos:", view_pos_ref.grad, view_pos_cuda.grad)
    relative_loss("perturbed_nrm:", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad)
    relative_loss("smooth_nrm:", smooth_nrm_ref.grad, smooth_nrm_cuda.grad)
    relative_loss("smooth_tng:", smooth_tng_ref.grad, smooth_tng_cuda.grad)
    relative_loss("geom_nrm:", geom_nrm_ref.grad, geom_nrm_cuda.grad)

def test_schlick():
    f0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    f0_ref = f0_cuda.clone().detach().requires_grad_(True)
    f90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    f90_ref = f90_cuda.clone().detach().requires_grad_(True)
    cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0
    cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
    cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')

    ref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    cuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print("    Fresnel shlick")
    print("-------------------------------------------------------------")
    relative_loss("res:", ref, cuda)
    relative_loss("f0:", f0_ref.grad, f0_cuda.grad)
    relative_loss("f90:", f90_ref.grad, f90_cuda.grad)
    relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)

def test_ndf_ggx():
    alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    alphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True)
    alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
    cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1
    cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
    cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

    ref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    cuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print("    Ndf GGX")
    print("-------------------------------------------------------------")
    relative_loss("res:", ref, cuda)
    relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
    relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)

def test_lambda_ggx():
    alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
    cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1
    cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
    cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

    ref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    cuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print("    Lambda GGX")
    print("-------------------------------------------------------------")
    relative_loss("res:", ref, cuda)
    relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
    relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)

def test_masking_smith():
    alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
    cosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    cosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True)
    cosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    cosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

    ref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    cuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print("    Smith masking term")
    print("-------------------------------------------------------------")
    relative_loss("res:", ref, cuda)
    relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
    relative_loss("cosThetaI:", cosThetaI_ref.grad, cosThetaI_cuda.grad)
    relative_loss("cosThetaO:", cosThetaO_ref.grad, cosThetaO_cuda.grad)

def test_lambert():
    normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    normals_ref = normals_cuda.clone().detach().requires_grad_(True)
    wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wi_ref = wi_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

    ref = ru.lambert(normals_ref, wi_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    cuda = ru.lambert(normals_cuda, wi_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print("    Lambert")
    print("-------------------------------------------------------------")
    relative_loss("res:", ref, cuda)
    relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
    relative_loss("wi:", wi_ref.grad, wi_cuda.grad)

def test_pbr_specular():
    col_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    col_ref = col_cuda.clone().detach().requires_grad_(True)
    nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
    wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wi_ref = wi_cuda.clone().detach().requires_grad_(True)
    wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wo_ref = wo_cuda.clone().detach().requires_grad_(True)
    alpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    alpha_ref = alpha_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')

    ref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    cuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print("    Pbr specular")
    print("-------------------------------------------------------------")

    relative_loss("res:", ref, cuda)
    if col_ref.grad is not None:
        relative_loss("col:", col_ref.grad, col_cuda.grad)
    if nrm_ref.grad is not None:
        relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
    if wi_ref.grad is not None:
        relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
    if wo_ref.grad is not None:
        relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
    if alpha_ref.grad is not None:
        relative_loss("alpha:", alpha_ref.grad, alpha_cuda.grad)

def test_pbr_bsdf():
    kd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    kd_ref = kd_cuda.clone().detach().requires_grad_(True)
    arm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    arm_ref = arm_cuda.clone().detach().requires_grad_(True)
    pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    pos_ref = pos_cuda.clone().detach().requires_grad_(True)
    nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
    view_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    view_ref = view_cuda.clone().detach().requires_grad_(True)
    light_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    light_ref = light_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')

    ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print("    Pbr BSDF")
    print("-------------------------------------------------------------")

    relative_loss("res:", ref, cuda)
    if kd_ref.grad is not None:
        relative_loss("kd:", kd_ref.grad, kd_cuda.grad)
    if arm_ref.grad is not None:
        relative_loss("arm:", arm_ref.grad, arm_cuda.grad)
    if pos_ref.grad is not None:
        relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
    if nrm_ref.grad is not None:
        relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
    if view_ref.grad is not None:
        relative_loss("view:", view_ref.grad, view_cuda.grad)
    if light_ref.grad is not None:
        relative_loss("light:", light_ref.grad, light_cuda.grad)

test_normal()

test_schlick()
test_ndf_ggx()
test_lambda_ggx()
test_masking_smith()

test_lambert()
test_pbr_specular()
test_pbr_bsdf()
nvdiffmodeling/src/renderutils/tests/test_loss.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch

import os
import sys
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import renderutils as ru

RES = 8
DTYPE = torch.float32

def tonemap_srgb(f):
    return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)

def l1(output, target):
    x = torch.clamp(output, min=0, max=65535)
    r = torch.clamp(target, min=0, max=65535)
    x = tonemap_srgb(torch.log(x + 1))
    r = tonemap_srgb(torch.log(r + 1))
    return torch.nn.functional.l1_loss(x,r)

def relative_loss(name, ref, cuda):
    ref = ref.float()
    cuda = cuda.float()
    print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())

def test_loss(loss, tonemapper):
    img_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    img_ref = img_cuda.clone().detach().requires_grad_(True)
    target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    target_ref = target_cuda.clone().detach().requires_grad_(True)

    ref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True)
    ref_loss.backward()

    cuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print("    Loss: %s, %s" % (loss, tonemapper))
    print("-------------------------------------------------------------")

    relative_loss("res:", ref_loss, cuda_loss)
    relative_loss("img:", img_ref.grad, img_cuda.grad)
    relative_loss("target:", target_ref.grad, target_cuda.grad)


test_loss('l1', 'none')
test_loss('l1', 'log_srgb')
test_loss('mse', 'log_srgb')
test_loss('smape', 'none')
test_loss('relmse', 'none')
test_loss('mse', 'none')
nvdiffmodeling/src/renderutils/tests/test_mesh.py
ADDED
@@ -0,0 +1,90 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch

import os
import sys
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import renderutils as ru

BATCH = 8
RES = 1024
DTYPE = torch.float32

torch.manual_seed(0)

def tonemap_srgb(f):
    return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)

def l1(output, target):
    x = torch.clamp(output, min=0, max=65535)
    r = torch.clamp(target, min=0, max=65535)
    x = tonemap_srgb(torch.log(x + 1))
    r = tonemap_srgb(torch.log(r + 1))
    return torch.nn.functional.l1_loss(x,r)

def relative_loss(name, ref, cuda):
    ref = ref.float()
    cuda = cuda.float()
    print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref)).item())

def test_xfm_points():
    points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    points_ref = points_cuda.clone().detach().requires_grad_(True)
    mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False)
    mtx_ref = mtx_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True)

    ref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref_out, target)
    ref_loss.backward()

    cuda_out = ru.xfm_points(points_cuda, mtx_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda_out, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")

    relative_loss("res:", ref_out, cuda_out)
    relative_loss("points:", points_ref.grad, points_cuda.grad)

def test_xfm_vectors():
    points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    points_ref = points_cuda.clone().detach().requires_grad_(True)
    points_cuda_p = points_cuda.clone().detach().requires_grad_(True)
    points_ref_p = points_cuda.clone().detach().requires_grad_(True)
    mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False)
    mtx_ref = mtx_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True)

    ref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3])
    ref_loss.backward()

    cuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3])
    cuda_loss.backward()

    ref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True)
    ref_loss_p = torch.nn.MSELoss()(ref_out_p, target)
    ref_loss_p.backward()

    cuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda)
    cuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target)
    cuda_loss_p.backward()

    print("-------------------------------------------------------------")

    relative_loss("res:", ref_out, cuda_out)
    relative_loss("points:", points_ref.grad, points_cuda.grad)
    relative_loss("points_p:", points_ref_p.grad, points_cuda_p.grad)

test_xfm_points()
test_xfm_vectors()
nvdiffmodeling/src/renderutils/tests/test_perf.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch

import os
import sys
import time
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import renderutils as ru

DTYPE=torch.float32

def test_bsdf(BATCH, RES, ITR):
    kd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    kd_ref = kd_cuda.clone().detach().requires_grad_(True)
    arm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    arm_ref = arm_cuda.clone().detach().requires_grad_(True)
    pos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    pos_ref = pos_cuda.clone().detach().requires_grad_(True)
    nrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
    view_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    view_ref = view_cuda.clone().detach().requires_grad_(True)
    light_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    light_ref = light_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(BATCH, RES, RES, 3, device='cuda')

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)

    print("--- Testing: [%d, %d, %d] ---" % (BATCH, RES, RES))

    start.record()
    for i in range(ITR):
        ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True)
    end.record()
    torch.cuda.synchronize()
    print("Pbr BSDF python:", start.elapsed_time(end))

    start.record()
    for i in range(ITR):
        cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)
    end.record()
    torch.cuda.synchronize()
    print("Pbr BSDF cuda:", start.elapsed_time(end))

test_bsdf(1, 512, 1000)
test_bsdf(16, 512, 1000)
test_bsdf(1, 2048, 1000)
nvdiffmodeling/src/texture.py
ADDED
@@ -0,0 +1,151 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import torch
import nvdiffrast.torch as dr

from . import util

########################################################################################################
# Simple texture class. A texture can be either
# - A 3D tensor (using auto mipmaps)
# - A list of 3D tensors (full custom mip hierarchy)
########################################################################################################

class Texture2D:
    # Initializes a texture from image data.
    # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays)
    def __init__(self, init):
        if isinstance(init, np.ndarray):
            init = torch.tensor(init, dtype=torch.float32, device='cuda')
        elif isinstance(init, list) and len(init) == 1:
            init = init[0]

        if isinstance(init, list) or len(init.shape) == 4:
            self.data = init
        elif len(init.shape) == 3:
            self.data = init[None, ...]
        else:
            self.data = init[None, None, None, :] # Convert constant to 1x1 tensor

    # Filtered (trilinear) sample texture at a given location
    def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear', data_fmt=torch.float32):
        if isinstance(self.data, list):
            out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode)
        else:
            out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode)
        return out.to(data_fmt)

    def getRes(self):
        return self.getMips()[0].shape[1:3]

    def getMips(self):
        if isinstance(self.data, list):
            return self.data
        else:
            return [self.data]

    # In-place clamp with no derivative to make sure values are in valid range after training
    def clamp_(self, min=None, max=None):
        with torch.no_grad():
            for mip in self.getMips():
                mip.clamp_(min=min, max=max)

    # In-place clamp with no derivative to make sure values are in valid range after training
    def clamp_rgb_(self, minR=None, maxR=None, minG=None, maxG=None, minB=None, maxB=None):
        with torch.no_grad():
            for mip in self.getMips():
                mip[...,0].clamp_(min=minR, max=maxR)
                mip[...,1].clamp_(min=minG, max=maxG)
                mip[...,2].clamp_(min=minB, max=maxB)

########################################################################################################
# Helper function to create a trainable texture from a regular texture. The trainable weights are
# initialized with texture data as an initial guess
########################################################################################################

def create_trainable(init, res, auto_mipmaps):
    with torch.no_grad():
        if isinstance(init, Texture2D):
            assert isinstance(init.data, torch.Tensor)
            init = init.data
        elif isinstance(init, np.ndarray):
            init = torch.tensor(init, dtype=torch.float32, device='cuda')

        # Pad to NHWC if needed
        if len(init.shape) == 1: # Extend constant to NHWC tensor
            init = init[None, None, None, :]
        elif len(init.shape) == 3:
            init = init[None, ...]

        # Scale input to desired resolution.
        init = util.scale_img_nhwc(init, res)

        # Generate custom mipchain
        if not auto_mipmaps:
            mip_chain = [init.clone().detach().requires_grad_(True)]
            while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1:
                new_size = [max(mip_chain[-1].shape[1] // 2, 1), max(mip_chain[-1].shape[2] // 2, 1)]
                init = util.scale_img_nhwc(mip_chain[-1], new_size)
                mip_chain += [init.clone().detach().requires_grad_(True)]
            return Texture2D(mip_chain)
        else:
            return Texture2D(init.clone().detach().requires_grad_(True))

########################################################################################################
# Convert texture to and from SRGB
########################################################################################################

def srgb_to_rgb(texture):
    return Texture2D(list(util.srgb_to_rgb(mip) for mip in texture.getMips()))

def rgb_to_srgb(texture):
    return Texture2D(list(util.rgb_to_srgb(mip) for mip in texture.getMips()))

########################################################################################################
# Utility functions for loading / storing a texture
########################################################################################################

def _load_mip2D(fn, lambda_fn=None, channels=None):
    imgdata = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')
    if channels is not None:
        imgdata = imgdata[..., 0:channels]
    if lambda_fn is not None:
        imgdata = lambda_fn(imgdata)
    return imgdata.detach().clone()

def load_texture2D(fn, lambda_fn=None, channels=None):
    base, ext = os.path.splitext(fn)
    if os.path.exists(base + "_0" + ext):
        mips = []
        while os.path.exists(base + ("_%d" % len(mips)) + ext):
            mips += [_load_mip2D(base + ("_%d" % len(mips)) + ext, lambda_fn, channels)]
        return Texture2D(mips)
    else:
        return Texture2D(_load_mip2D(fn, lambda_fn, channels))

def _save_mip2D(fn, mip, mipidx, lambda_fn):
    if lambda_fn is not None:
        data = lambda_fn(mip).detach().cpu().numpy()
    else:
        data = mip.detach().cpu().numpy()

    if mipidx is None:
        util.save_image(fn, data)
    else:
        base, ext = os.path.splitext(fn)
        util.save_image(base + ("_%d" % mipidx) + ext, data)

def save_texture2D(fn, tex, lambda_fn=None):
    if isinstance(tex.data, list):
        for i, mip in enumerate(tex.data):
            _save_mip2D(fn, mip[0,...], i, lambda_fn)
    else:
        _save_mip2D(fn, tex.data[0,...], None, lambda_fn)
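
As a quick illustration of how the Texture2D helpers above are typically driven during optimization, a minimal sketch follows. It is not part of the uploaded sources: the import path, the 512x512 resolution, and the constant grey initializer are all assumptions, and a CUDA device plus the sibling util module are required.

# Hypothetical usage sketch (not in the upstream files).
import numpy as np
import torch

from nvdiffmodeling.src import texture  # assumed import path; adjust to the actual package layout

# Turn a constant grey albedo into a trainable 512x512 texture (auto mipmaps).
kd = texture.create_trainable(np.array([0.5, 0.5, 0.5]), [512, 512], auto_mipmaps=True)
optimizer = torch.optim.Adam(kd.getMips(), lr=1e-2)

# ... render with kd, compute a loss, loss.backward(), optimizer.step() ...

# Keep values in a valid range after each step, then write the result to disk.
kd.clamp_(min=0.0, max=1.0)
texture.save_texture2D('kd_out.png', texture.rgb_to_srgb(kd))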
nvdiffmodeling/src/util.py
ADDED
@@ -0,0 +1,354 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import sys
import numpy as np
import torch
import nvdiffrast.torch as dr
import imageio

#----------------------------------------------------------------------------
# Vector operations
#----------------------------------------------------------------------------

def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.sum(x*y, -1, keepdim=True)

def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    return 2*dot(x, n)*n - x

def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN

def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    return x / length(x, eps)

def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
    return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w)

#----------------------------------------------------------------------------
# Tonemapping
#----------------------------------------------------------------------------

def tonemap_srgb(f: torch.Tensor) -> torch.Tensor:
    return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)

#----------------------------------------------------------------------------
# sRGB color transforms
#----------------------------------------------------------------------------

def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055)

def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    assert f.shape[-1] == 3 or f.shape[-1] == 4
    out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f)
    assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]
    return out

def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4))

def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    assert f.shape[-1] == 3 or f.shape[-1] == 4
    out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f)
    assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]
    return out

#----------------------------------------------------------------------------
# Displacement texture lookup
#----------------------------------------------------------------------------

def get_miplevels(texture: np.ndarray) -> float:
    minDim = min(texture.shape[0], texture.shape[1])
    return np.floor(np.log2(minDim))

# TODO: Handle wrapping maybe
def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor:
    tex_map = tex_map[None, ...]    # Add batch dimension
    tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW
    tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False)
    tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC
    return tex[0, 0, ...]

#----------------------------------------------------------------------------
# Image scaling
#----------------------------------------------------------------------------

def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:
    return scale_img_nhwc(x[None, ...], size, mag, min)[0]

def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:
    assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other"
    y = x.permute(0, 3, 1, 2) # NHWC -> NCHW
    if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger
        y = torch.nn.functional.interpolate(y, size, mode=min)
    else: # Magnification
        if mag == 'bilinear' or mag == 'bicubic':
            y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)
        else:
            y = torch.nn.functional.interpolate(y, size, mode=mag)
    return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC

def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor:
    y = x.permute(0, 3, 1, 2) # NHWC -> NCHW
    y = torch.nn.functional.avg_pool2d(y, size)
    return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC

#----------------------------------------------------------------------------
# Behaves similar to tf.segment_sum
#----------------------------------------------------------------------------

def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
    num_segments = torch.unique_consecutive(segment_ids).shape[0]

    # Repeats ids until same dimension as data
    if len(segment_ids.shape) == 1:
        s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long()
        segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:])

    assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"

    shape = [num_segments] + list(data.shape[1:])
    result = torch.zeros(*shape, dtype=torch.float32, device='cuda')
    result = result.scatter_add(0, segment_ids, data)
    return result

#----------------------------------------------------------------------------
# Projection and transformation matrix helpers.
#----------------------------------------------------------------------------

def projection(x=0.1, n=1.0, f=50.0):
    return np.array([[n/x,    0,            0,              0],
                     [  0, n/-x,            0,              0],
                     [  0,    0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
                     [  0,    0,           -1,              0]]).astype(np.float32)

def translate(x, y, z):
    return np.array([[1, 0, 0, x],
                     [0, 1, 0, y],
                     [0, 0, 1, z],
                     [0, 0, 0, 1]]).astype(np.float32)

def rotate_x(a):
    s, c = np.sin(a), np.cos(a)
    return np.array([[1, 0, 0, 0],
                     [0, c, s, 0],
                     [0, -s, c, 0],
                     [0, 0, 0, 1]]).astype(np.float32)

def rotate_y(a):
    s, c = np.sin(a), np.cos(a)
    return np.array([[ c, 0, s, 0],
                     [ 0, 1, 0, 0],
                     [-s, 0, c, 0],
                     [ 0, 0, 0, 1]]).astype(np.float32)

def scale(s):
    return np.array([[ s, 0, 0, 0],
                     [ 0, s, 0, 0],
                     [ 0, 0, s, 0],
                     [ 0, 0, 0, 1]]).astype(np.float32)

def lookAt(eye, at, up):
    a = eye - at
    b = up
    w = a / np.linalg.norm(a)
    u = np.cross(b, w)
    u = u / np.linalg.norm(u)
    v = np.cross(w, u)
    translate = np.array([[1, 0, 0, -eye[0]],
                          [0, 1, 0, -eye[1]],
                          [0, 0, 1, -eye[2]],
                          [0, 0, 0, 1]]).astype(np.float32)
    rotate = np.array([[u[0], u[1], u[2], 0],
                       [v[0], v[1], v[2], 0],
                       [w[0], w[1], w[2], 0],
                       [0, 0, 0, 1]]).astype(np.float32)
    return np.matmul(rotate, translate)

def random_rotation_translation(t):
    m = np.random.normal(size=[3, 3])
    m[1] = np.cross(m[0], m[2])
    m[2] = np.cross(m[0], m[1])
    m = m / np.linalg.norm(m, axis=1, keepdims=True)
    m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
    m[3, 3] = 1.0
    m[:3, 3] = np.random.uniform(-t, t, size=[3])
    return m


#----------------------------------------------------------------------------
# Cosine sample around a vector N
#----------------------------------------------------------------------------
def cosine_sample(N : np.ndarray) -> np.ndarray:
    # construct local frame
    N = N/np.linalg.norm(N)

    dx0 = np.array([0, N[2], -N[1]])
    dx1 = np.array([-N[2], 0, N[0]])

    dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1
    dx = dx/np.linalg.norm(dx)
    dy = np.cross(N,dx)
    dy = dy/np.linalg.norm(dy)

    # cosine sampling in local frame
    phi = 2.0*np.pi*np.random.uniform()
    s = np.random.uniform()
    costheta = np.sqrt(s)
    sintheta = np.sqrt(1.0 - s)

    # cartesian vector in local space
    x = np.cos(phi)*sintheta
    y = np.sin(phi)*sintheta
    z = costheta

    # local to world
    return dx*x + dy*y + N*z


#----------------------------------------------------------------------------
# Cosine sampled light directions around the vector N
#----------------------------------------------------------------------------
def cosine_sample_texture(res, N : np.ndarray) -> torch.Tensor:
    # construct local frame
    N = N/np.linalg.norm(N)

    dx0 = np.array([0, N[2], -N[1]])
    dx1 = np.array([-N[2], 0, N[0]])

    dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1
    dx = dx/np.linalg.norm(dx)
    dy = np.cross(N,dx)
    dy = dy/np.linalg.norm(dy)

    X = torch.tensor(dx, dtype=torch.float32, device='cuda')
    Y = torch.tensor(dy, dtype=torch.float32, device='cuda')
    Z = torch.tensor(N, dtype=torch.float32, device='cuda')

    # cosine sampling in local frame

    phi = 2.0*np.pi*torch.rand(res, res, 1, dtype=torch.float32, device='cuda')
    s = torch.rand(res, res, 1, dtype=torch.float32, device='cuda')
    costheta = torch.sqrt(s)
    sintheta = torch.sqrt(1.0 - s)

    # cartesian vector in local space
    x = torch.cos(phi)*sintheta
    y = torch.sin(phi)*sintheta
    z = costheta

    # local to world
    return X*x + Y*y + Z*z

#----------------------------------------------------------------------------
# Bilinear downsample by 2x.
#----------------------------------------------------------------------------

def bilinear_downsample(x : torch.tensor) -> torch.Tensor:
    w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0
    w = w.expand(x.shape[-1], 1, 4, 4)
    x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1])
    return x.permute(0, 2, 3, 1)

#----------------------------------------------------------------------------
# Bilinear downsample log(spp) steps
#----------------------------------------------------------------------------

def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor:
    w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0
    g = x.shape[-1]
    w = w.expand(g, 1, 4, 4)
    x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
    steps = int(np.log2(spp))
    for _ in range(steps):
        xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate')
        x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g)
    return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC


#----------------------------------------------------------------------------
# Image display function using OpenGL.
#----------------------------------------------------------------------------

_glfw_window = None
def display_image(image, zoom=None, size=None, title=None): # HWC
    # Import OpenGL and glfw.
    import OpenGL.GL as gl
    import glfw

    # Zoom image if requested.
    image = np.asarray(image)
    if size is not None:
        assert zoom is None
        zoom = max(1, size // image.shape[0])
    if zoom is not None:
        image = image.repeat(zoom, axis=0).repeat(zoom, axis=1)
    height, width, channels = image.shape

    # Initialize window.
    if title is None:
        title = 'Debug window'
    global _glfw_window
    if _glfw_window is None:
        glfw.init()
        _glfw_window = glfw.create_window(width, height, title, None, None)
        glfw.make_context_current(_glfw_window)
        glfw.show_window(_glfw_window)
        glfw.swap_interval(0)
    else:
        glfw.make_context_current(_glfw_window)
        glfw.set_window_title(_glfw_window, title)
        glfw.set_window_size(_glfw_window, width, height)

    # Update window.
    glfw.poll_events()
    gl.glClearColor(0, 0, 0, 1)
    gl.glClear(gl.GL_COLOR_BUFFER_BIT)
    gl.glWindowPos2f(0, 0)
    gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
    gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels]
    gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name]
    gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1])
    glfw.swap_buffers(_glfw_window)
    if glfw.window_should_close(_glfw_window):
        return False
    return True

#----------------------------------------------------------------------------
# Image save helper.
#----------------------------------------------------------------------------

def save_image(fn, x : np.ndarray) -> np.ndarray:
    imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8))

def load_image(fn) -> np.ndarray:
    img = imageio.imread(fn)
    if img.dtype == np.float32: # HDR image
        return img
    else: # LDR image
        return img.astype(np.float32) / 255

#----------------------------------------------------------------------------

def time_to_text(x):
    if x > 3600:
        return "%.2f h" % (x / 3600)
    elif x > 60:
        return "%.2f m" % (x / 60)
    else:
        return "%.2f s" % x

#----------------------------------------------------------------------------

def checkerboard(width, repetitions) -> np.ndarray:
    tilesize = int(width//repetitions//2)
    check = np.kron([[1, 0] * repetitions, [0, 1] * repetitions] * repetitions, np.ones((tilesize, tilesize)))*0.33 + 0.33
    return np.stack((check, check, check), axis=-1)[None, ...]
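
Finally, to show how the matrix helpers in util.py compose, here is a short sketch that builds a model-view-projection matrix for a randomly perturbed camera. It is illustrative only: the import path, field-of-view, and camera distance are assumptions, and the transform is applied with plain torch.matmul rather than the CUDA xfm_points op.

# Hypothetical usage sketch (not in the upstream files).
import numpy as np
import torch

from nvdiffmodeling.src import util  # assumed import path; adjust to the actual package layout

proj = util.projection(x=0.4, n=1.0, f=200.0)                    # perspective projection matrix
mv   = util.translate(0, 0, -3.0) @ util.random_rotation_translation(0.25)
mvp  = proj @ mv                                                 # full model-view-projection

# Transform a small batch of points to clip space, mirroring what xfm_points does.
verts   = torch.rand(1, 16, 3)
verts_h = torch.nn.functional.pad(verts, (0, 1), value=1.0)      # homogeneous coordinates, [1, 16, 4]
clip    = torch.matmul(verts_h, torch.tensor(mvp, dtype=torch.float32).t()[None, ...])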