Commit 5f58b04
Parent(s): 7589ff6
init

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +16 -0
- .gitignore +7 -0
- Dockerfile +56 -0
- README.md +4 -4
- app.py +80 -0
- app_3d.py +21 -0
- app_canny.py +83 -0
- app_matnet.py +83 -0
- app_sd.py +154 -0
- app_texnet.py +259 -0
- cv_utils.py +17 -0
- depth_estimator.py +25 -0
- environment.yml +7 -0
- examples/bunny/frame_0001.png +3 -0
- examples/bunny/mesh.obj +3 -0
- examples/bunny/uv_normal.png +3 -0
- examples/fighter/frame_0001.png +3 -0
- examples/fighter/mesh.obj +3 -0
- examples/fighter/uv_normal.png +3 -0
- examples/highheel/frame_0001.png +3 -0
- examples/highheel/mesh.obj +3 -0
- examples/highheel/uv_normal.png +3 -0
- examples/monkey/frame_0001.png +3 -0
- examples/monkey/mesh.obj +3 -0
- examples/monkey/uv_normal.png +3 -0
- examples/tank/frame_0001.png +3 -0
- examples/tank/mesh.obj +3 -0
- examples/tank/uv_normal.png +3 -0
- examples/tshirt/frame_0001.png +3 -0
- examples/tshirt/mesh.obj +3 -0
- examples/tshirt/uv_normal.png +3 -0
- image_segmentor.py +33 -0
- install.sh +18 -0
- model.py +959 -0
- preprocessor.py +120 -0
- push_dataset.py +9 -0
- rgb2x/generate_blend.py +142 -0
- rgb2x/gradio_demo_rgb2x.py +157 -0
- rgb2x/load_image.py +119 -0
- rgb2x/pipeline_rgb2x.py +821 -0
- run.sh +14 -0
- settings.py +23 -0
- text2tex/lib/__init__.py +0 -0
- text2tex/lib/camera_helper.py +231 -0
- text2tex/lib/constants.py +648 -0
- text2tex/lib/diffusion_helper.py +189 -0
- text2tex/lib/io_helper.py +78 -0
- text2tex/lib/mesh_helper.py +148 -0
- text2tex/lib/projection_helper.py +464 -0
- text2tex/lib/render_helper.py +108 -0
.gitattributes
CHANGED
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+uv_normal.png filter=lfs diff=lfs merge=lfs -text
+mesh.obj filter=lfs diff=lfs merge=lfs -text
+frame_0001.png filter=lfs diff=lfs merge=lfs -text
+examples/bunny/uv_normal.png filter=lfs diff=lfs merge=lfs -text
+examples/fighter/frame_0001.png filter=lfs diff=lfs merge=lfs -text
+examples/fighter/uv_normal.png filter=lfs diff=lfs merge=lfs -text
+examples/highheel/frame_0001.png filter=lfs diff=lfs merge=lfs -text
+examples/highheel/uv_normal.png filter=lfs diff=lfs merge=lfs -text
+examples/monkey/frame_0001.png filter=lfs diff=lfs merge=lfs -text
+examples/monkey/uv_normal.png filter=lfs diff=lfs merge=lfs -text
+examples/tank/frame_0001.png filter=lfs diff=lfs merge=lfs -text
+examples/tank/uv_normal.png filter=lfs diff=lfs merge=lfs -text
+examples/tshirt/frame_0001.png filter=lfs diff=lfs merge=lfs -text
+examples/tshirt/mesh.obj filter=lfs diff=lfs merge=lfs -text
+examples/tshirt/uv_normal.png filter=lfs diff=lfs merge=lfs -text
+examples/bunny/frame_0001.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,7 @@
__pycache__
data
# examples
.gradio
model_cache
output
test.png
Dockerfile
ADDED
@@ -0,0 +1,56 @@
FROM continuumio/anaconda3:main

# make sure cv2 can be loaded
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y

WORKDIR /code
COPY ./environment.yml /code/environment.yml

# Create the environment using the environment.yml file
RUN conda env create -f /code/environment.yml

# install pip packages to the gradio environment
# when adjusting the Dockerfile on huggingface:
# - if the Dockerfile builds successfully, a new Space needs to be initialized and the changes pushed accordingly
# - otherwise, you can keep committing to the failed-build Dockerfile for debugging
RUN conda run -n gradio pip install --upgrade pip
# RUN conda run -n gradio pip install diffusers["torch"] transformers accelerate xformers
# RUN conda run -n gradio pip install gradio
# RUN conda run -n gradio pip install controlnet-aux
# RUN conda install -n gradio pytorch3d -c pytorch3d -c conda-forge
# RUN conda install -n gradio -c conda-forge open-clip-torch pytorch-lightning
# RUN conda run -n gradio pip install trimesh xatlas scikit-learn opencv-python omegaconf

# Set the environment variable to use the gradio environment by default
# RUN echo "source activate gradio" > ~/.bashrc
# ENV PATH /opt/conda/envs/gradio/bin:$PATH

# Set up a new user named "user" with user ID 1000
RUN useradd -m -u 1000 user
# Switch to the "user" user
USER user
RUN conda create -n gradio-user python=3.11
RUN conda run -n gradio-user pip install --upgrade pip
# RUN conda install -n gradio-user pytorch3d=0.7.7 -c pytorch3d -c conda-forge
# RUN conda install -n gradio-user -c conda-forge open-clip-torch pytorch-lightning
RUN conda run -n gradio-user pip install diffusers transformers accelerate xformers controlnet-aux gradio spaces trimesh xatlas scikit-learn opencv-python matplotlib omegaconf

# Set home to the user's home directory
ENV HOME=/home/user \
    PYTHONPATH=$HOME/app \
    PYTHONUNBUFFERED=1 \
    GRADIO_ALLOW_FLAGGING=never \
    GRADIO_SERVER_NAME=0.0.0.0 \
    GRADIO_THEME=huggingface \
    SYSTEM=spaces
# Set the working directory to the user's home directory
WORKDIR $HOME/app

# Copy the current directory contents into the container at $HOME/app, setting the owner to the user
COPY --chown=user . $HOME/app

# download https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/control_sd15_depth.pth?download=true to $HOME/app/text2tex/models/ControlNet/models/control_sd15_depth.pth
RUN mkdir -p $HOME/app/text2tex/models/ControlNet/models && \
    wget -O $HOME/app/text2tex/models/ControlNet/models/control_sd15_depth.pth https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/control_sd15_depth.pth?download=true

CMD ["./run.sh"]
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title: Docker
-emoji:
-colorFrom:
-colorTo:
+title: Docker Test6
+emoji: 👀
+colorFrom: pink
+colorTo: pink
 sdk: docker
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,80 @@
#!/usr/bin/env python

import gradio as gr
import torch

import sys
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
version_str="".join([
    f"py3{sys.version_info.minor}_cu",
    torch.version.cuda.replace(".",""),
    f"_pyt{pyt_version_str}"
])
print(f"Using version: {version_str}") # used to locate pytorch3d version in the requirements.txt for huggingface


from app_canny import create_demo as create_demo_canny
from app_texnet import create_demo as create_demo_texnet

from model import Model
from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON

DESCRIPTION = "# Material Authoring Demo v0.3"

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p> Check if the 'CUDA_VISIBLE_DEVICES' are set incorrectly in settings.py"

# model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="Canny")
model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="texnet")

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=SHOW_DUPLICATE_BUTTON,
    )

    with gr.Tabs():
        with gr.Tab("Texnet+Matnet"):
            create_demo_texnet(model.process_texnet)

    with gr.Accordion(label="Base model", open=False):
        with gr.Row():
            with gr.Column(scale=5):
                current_base_model = gr.Text(label="Current base model")
            with gr.Column(scale=1):
                check_base_model_button = gr.Button("Check current base model")
        with gr.Row():
            with gr.Column(scale=5):
                new_base_model_id = gr.Text(
                    label="New base model",
                    max_lines=1,
                    placeholder="stable-diffusion-v1-5/stable-diffusion-v1-5",
                    info="The base model must be compatible with Stable Diffusion v1.5.",
                    interactive=ALLOW_CHANGING_BASE_MODEL,
                )
            with gr.Column(scale=1):
                change_base_model_button = gr.Button("Change base model", interactive=ALLOW_CHANGING_BASE_MODEL)
        if not ALLOW_CHANGING_BASE_MODEL:
            gr.Markdown(
                """The base model is not allowed to be changed in this Space so as not to slow down the demo, but it can be changed if you duplicate the Space."""
            )

    check_base_model_button.click(
        fn=lambda: model.base_model_id,
        outputs=current_base_model,
        queue=False,
        api_name="check_base_model",
    )
    gr.on(
        triggers=[new_base_model_id.submit, change_base_model_button.click],
        fn=model.set_base_model,
        inputs=new_base_model_id,
        outputs=current_base_model,
        api_name=False,
        concurrency_id="main",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
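For context (not a file in this commit): the version_str printed at the top of app.py mirrors the naming scheme of the prebuilt pytorch3d wheels. A minimal sketch of how such a tag is typically turned into an install command, assuming the public wheel index layout at dl.fbaipublicfiles.com (an assumption, not something this repo pins):

import sys
import torch

# Rebuild the same tag app.py prints, e.g. "py311_cu121_pyt221".
pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
version_str = "".join([
    f"py3{sys.version_info.minor}_cu",
    torch.version.cuda.replace(".", ""),
    f"_pyt{pyt_version_str}",
])

# Hypothetical install line built from the tag; adjust if the index layout differs.
print(
    "pip install --no-index pytorch3d "
    f"-f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html"
)
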
app_3d.py
ADDED
@@ -0,0 +1,21 @@
import gradio as gr
import os

def load_mesh(mesh_file_name):
    return mesh_file_name

demo = gr.Interface(
    fn=load_mesh,
    inputs=gr.Model3D(),
    outputs=gr.Model3D(
        clear_color=(255.0, 0.0, 0.0, 0.0), label="3D Model", display_mode="wireframe"),
    examples=[
        [os.path.join(os.path.dirname(__file__), "examples/bunny/mesh.obj")],
        [os.path.join(os.path.dirname(__file__), "examples/monkey/mesh.obj")],
        [os.path.join(os.path.dirname(__file__), "examples/Bunny.obj")],
    ],
    cache_examples=True
)

if __name__ == "__main__":
    demo.launch()
app_canny.py
ADDED
@@ -0,0 +1,83 @@
#!/usr/bin/env python

import gradio as gr

from settings import (
    DEFAULT_IMAGE_RESOLUTION,
    DEFAULT_NUM_IMAGES,
    MAX_IMAGE_RESOLUTION,
    MAX_NUM_IMAGES,
    MAX_SEED,
)
from utils import randomize_seed_fn


def create_demo(process):
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image()
                prompt = gr.Textbox(label="Prompt", submit_btn=True)
                with gr.Accordion("Advanced options", open=False):
                    num_samples = gr.Slider(
                        label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
                    )
                    image_resolution = gr.Slider(
                        label="Image resolution",
                        minimum=256,
                        maximum=MAX_IMAGE_RESOLUTION,
                        value=DEFAULT_IMAGE_RESOLUTION,
                        step=256,
                    )
                    canny_low_threshold = gr.Slider(
                        label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
                    )
                    canny_high_threshold = gr.Slider(
                        label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
                    )
                    num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
                    n_prompt = gr.Textbox(
                        label="Negative prompt",
                        value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                    )
            with gr.Column():
                result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
        inputs = [
            image,
            prompt,
            a_prompt,
            n_prompt,
            num_samples,
            image_resolution,
            num_steps,
            guidance_scale,
            seed,
            canny_low_threshold,
            canny_high_threshold,
        ]
        prompt.submit(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=process,
            inputs=inputs,
            outputs=result,
            api_name="canny",
            concurrency_id="main",
        )
    return demo


if __name__ == "__main__":
    from model import Model

    model = Model(task_name="Canny")
    demo = create_demo(model.process_canny)
    demo.queue().launch()
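utils.py (and its randomize_seed_fn) sits outside this 50-file view. A plausible minimal implementation, assuming it only draws a fresh seed when the checkbox is ticked (illustrative, not the repo's actual code):

import random

MAX_SEED = 2**31 - 1  # assumption; the real bound lives in settings.py as MAX_SEED


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    # Keep the slider value unless the user asked for a random seed.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
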
app_matnet.py
ADDED
@@ -0,0 +1,83 @@
#!/usr/bin/env python

import gradio as gr

from settings import (
    DEFAULT_IMAGE_RESOLUTION,
    DEFAULT_NUM_IMAGES,
    MAX_IMAGE_RESOLUTION,
    MAX_NUM_IMAGES,
    MAX_SEED,
)
from utils import randomize_seed_fn


def create_demo(process):
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image()
                prompt = gr.Textbox(label="Prompt", submit_btn=True)
                with gr.Accordion("Advanced options", open=False):
                    num_samples = gr.Slider(
                        label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
                    )
                    image_resolution = gr.Slider(
                        label="Image resolution",
                        minimum=256,
                        maximum=MAX_IMAGE_RESOLUTION,
                        value=DEFAULT_IMAGE_RESOLUTION,
                        step=256,
                    )
                    canny_low_threshold = gr.Slider(
                        label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
                    )
                    canny_high_threshold = gr.Slider(
                        label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
                    )
                    num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
                    n_prompt = gr.Textbox(
                        label="Negative prompt",
                        value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                    )
            with gr.Column():
                result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
        inputs = [
            image,
            prompt,
            a_prompt,
            n_prompt,
            num_samples,
            image_resolution,
            num_steps,
            guidance_scale,
            seed,
            canny_low_threshold,
            canny_high_threshold,
        ]
        prompt.submit(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=process,
            inputs=inputs,
            outputs=result,
            api_name="canny",
            concurrency_id="main",
        )
    return demo


if __name__ == "__main__":
    from model import Model

    model = Model(task_name="Canny")
    demo = create_demo(model.process_canny)
    demo.queue().launch()
app_sd.py
ADDED
@@ -0,0 +1,154 @@
import gradio as gr
import numpy as np
import random

import spaces #[uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use

if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


@spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    return image, seed


examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image Gradio Template")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,  # Replace with defaults that work for your model
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,  # Replace with defaults that work for your model
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=0.0,  # Replace with defaults that work for your model
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=2,  # Replace with defaults that work for your model
                )

        gr.Examples(examples=examples, inputs=[prompt])
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()
app_texnet.py
ADDED
@@ -0,0 +1,259 @@
#!/usr/bin/env python

import os
import shutil
import tempfile
import gradio as gr
from PIL import Image
import numpy as np

from settings import (
    DEFAULT_IMAGE_RESOLUTION,
    DEFAULT_NUM_IMAGES,
    MAX_IMAGE_RESOLUTION,
    MAX_NUM_IMAGES,
    MAX_SEED,
)
from utils import randomize_seed_fn

# ---- helper to build a quick textured copy of the mesh ---------------
def apply_texture(src_mesh:str, texture:str, tag:str)->str:
    """
    Writes a copy of `src_mesh` and a tiny .mtl that points to `texture`.
    Returns the new OBJ/GLB path for viewing.
    """
    tmp_dir = tempfile.mkdtemp()
    mesh_copy = os.path.join(tmp_dir, f"{tag}.obj")
    mtl_name = f"{tag}.mtl"

    # copy geometry
    shutil.copy(src_mesh, mesh_copy)

    # write minimal MTL
    with open(os.path.join(tmp_dir, mtl_name), "w") as f:
        f.write(f"newmtl material_0\nmap_Kd {os.path.basename(texture)}\n")

    # ensure texture lives next to OBJ
    shutil.copy(texture, os.path.join(tmp_dir, os.path.basename(texture)))

    # patch OBJ to reference our new MTL
    with open(mesh_copy, "r+") as f:
        lines = f.readlines()
        if not lines[0].startswith("mtllib"):
            lines.insert(0, f"mtllib {mtl_name}\n")
        f.seek(0); f.writelines(lines)

    return mesh_copy

def image_to_temp_path(img_like, tag, out_dir=None):
    """
    Convert various image-like objects (str, PIL.Image, list, tuple) to a temp PNG path.
    Returns the path to the saved image file.
    """
    # Handle tuple or list input
    if isinstance(img_like, (list, tuple)):
        if len(img_like) == 0:
            raise ValueError("Empty image list/tuple.")
        img_like = img_like[0]

    # If it's already a file path
    if isinstance(img_like, str):
        return img_like

    # If it's a PIL Image
    if isinstance(img_like, Image.Image):
        temp_path = os.path.join(tempfile.mkdtemp() if out_dir is None else out_dir, f"{tag}.png")
        os.makedirs(os.path.dirname(temp_path), exist_ok=True)
        img_like.save(temp_path)
        return temp_path

    # if it's a numpy array
    if isinstance(img_like, np.ndarray):
        temp_path = os.path.join(tempfile.mkdtemp() if out_dir is None else out_dir, f"{tag}.png")
        os.makedirs(os.path.dirname(temp_path), exist_ok=True)
        img_like = Image.fromarray(img_like)
        img_like.save(temp_path)
        return temp_path

    raise ValueError(f"Expected PIL.Image, str, list, or tuple — got {type(img_like)}")

def show_mesh(which, mesh, inp, coarse, fine):
    """Switch the displayed texture based on dropdown change."""
    print()
    tex_map = {
        "Input": image_to_temp_path(inp, "input"),
        "Coarse": coarse[0] if isinstance(coarse, tuple) else coarse,
        "Fine": fine[0] if isinstance(fine, tuple) else fine,
    }
    texture_path = tex_map[which]
    return apply_texture(mesh, texture_path, which.lower())
# ----------------------------------------------------------------------


def create_demo(process):
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Select a preset from the example list, and modify the prompt accordingly")
                with gr.Row():
                    name = gr.Textbox(label="Name", interactive=False, visible=False)
                    representative = gr.Image(label="Geometry", interactive=False)
                    image = gr.Image(label="UV Normal", interactive=False)
                prompt = gr.Textbox(label="Prompt", submit_btn=True)
                with gr.Accordion("Advanced options", open=False):
                    num_samples = gr.Slider(
                        label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
                    )
                    image_resolution = gr.Slider(
                        label="Image resolution",
                        minimum=256,
                        maximum=MAX_IMAGE_RESOLUTION,
                        value=DEFAULT_IMAGE_RESOLUTION,
                        step=256,
                    )
                    num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=10, step=1)
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
                    n_prompt = gr.Textbox(
                        label="Negative prompt",
                        value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                    )
            with gr.Column():
                # 2x2 grid of images for the output textures
                gr.Markdown("### Output BRDF")
                with gr.Row():
                    base_color = gr.Gallery(label="Base Color", show_label=True, columns=1, object_fit="scale-down")
                    normal = gr.Gallery(label="Displacement Map", show_label=True, columns=1, object_fit="scale-down")
                with gr.Row():
                    roughness = gr.Gallery(label="Roughness Map", show_label=True, columns=1, object_fit="scale-down")
                    metallic = gr.Gallery(label="Metallic Map", show_label=True, columns=1, object_fit="scale-down")

                gr.Markdown("### Download Packed Blender Files for 3D Visualization")
                out_blender_path = gr.File(label="Generated Blender File", file_types=[".blend"])

        inputs = [
            name,  # Name of the object
            representative,  # Geometry mesh
            image,
            prompt,
            a_prompt,
            n_prompt,
            num_samples,
            image_resolution,
            num_steps,
            guidance_scale,
            seed,
        ]

        # first call → run diffusion / texture network
        prompt.submit(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=process,
            inputs=inputs,
            outputs=[base_color, normal, roughness, metallic, out_blender_path],
            api_name="canny",
            concurrency_id="main",
        )

        gr.Examples(
            fn=process,
            inputs=inputs,
            outputs=[base_color, normal, roughness, metallic],
            examples=[
                [
                    "bunny",
                    "examples/bunny/frame_0001.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/bunny/uv_normal/fused.png
                    "examples/bunny/uv_normal.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/bunny/uv_normal/fused.png
                    "feather",
                    a_prompt.value,
                    n_prompt.value,
                    num_samples.value,
                    image_resolution.value,
                    num_steps.value,
                    guidance_scale.value,
                    seed.value,
                ],
                [
                    "monkey",
                    "examples/monkey/frame_0001.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
                    "examples/monkey/uv_normal.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
                    "wood",
                    a_prompt.value,
                    n_prompt.value,
                    num_samples.value,
                    image_resolution.value,
                    num_steps.value,
                    guidance_scale.value,
                    seed.value,
                ],
                [
                    "tshirt",
                    "examples/tshirt/frame_0001.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
                    "examples/tshirt/uv_normal.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
                    "wood",
                    a_prompt.value,
                    n_prompt.value,
                    num_samples.value,
                    image_resolution.value,
                    num_steps.value,
                    guidance_scale.value,
                    seed.value,
                ],
                # [
                #     "highheel",
                #     "examples/highheel/frame_0001.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
                #     "examples/highheel/uv_normal.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
                #     "wood",
                #     a_prompt.value,
                #     n_prompt.value,
                #     num_samples.value,
                #     image_resolution.value,
                #     num_steps.value,
                #     guidance_scale.value,
                #     seed.value,
                # ],
                [
                    "tank",
                    "examples/tank/frame_0001.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
                    "examples/tank/uv_normal.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
                    "wood",
                    a_prompt.value,
                    n_prompt.value,
                    num_samples.value,
                    image_resolution.value,
                    num_steps.value,
                    guidance_scale.value,
                    seed.value,
                ],
                [
                    "fighter",
                    "examples/fighter/frame_0001.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
                    "examples/fighter/uv_normal.png",  # /dgxusers/Users/jyang/project/ObjectReal/data/control/preprocess/monkey/uv_normal/fused.png
                    "wood",
                    a_prompt.value,
                    n_prompt.value,
                    num_samples.value,
                    image_resolution.value,
                    num_steps.value,
                    guidance_scale.value,
                    seed.value,
                ],
            ],
        )

    return demo


if __name__ == "__main__":
    from model import Model

    model = Model(task_name="Texnet")
    demo = create_demo(model.process_texnet)
    demo.queue().launch()
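A usage illustration for the apply_texture helper above (the texture path is hypothetical; only the example meshes ship with this commit):

from app_texnet import apply_texture

# Copies the mesh into a temp dir, writes <tag>.mtl with a single map_Kd entry,
# copies the texture next to it, and prepends "mtllib <tag>.mtl" to the OBJ.
textured_obj = apply_texture("examples/bunny/mesh.obj", "/tmp/fine_texture.png", tag="fine")
print(textured_obj)  # e.g. /tmp/<random>/fine.obj, ready to hand to gr.Model3D
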
cv_utils.py
ADDED
@@ -0,0 +1,17 @@
import cv2
import numpy as np


def resize_image(input_image, resolution, interpolation=None):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / max(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    if interpolation is None:
        interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
    img = cv2.resize(input_image, (W, H), interpolation=interpolation)
    return img
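A quick check of resize_image's behaviour (synthetic input, illustration only): both sides are scaled by k = resolution / max(H, W) and snapped to multiples of 64.

import numpy as np
from cv_utils import resize_image

img = np.zeros((768, 1024, 3), dtype=np.uint8)  # H x W x C
out = resize_image(img, resolution=512)
print(out.shape)  # (384, 512, 3): k = 0.5, and 384 and 512 are already multiples of 64
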
depth_estimator.py
ADDED
@@ -0,0 +1,25 @@
import numpy as np
import PIL.Image
from controlnet_aux.util import HWC3
from transformers import pipeline

from cv_utils import resize_image


class DepthEstimator:
    def __init__(self):
        self.model = pipeline("depth-estimation")

    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        detect_resolution = kwargs.pop("detect_resolution", 512)
        image_resolution = kwargs.pop("image_resolution", 512)
        image = np.array(image)
        image = HWC3(image)
        image = resize_image(image, resolution=detect_resolution)
        image = PIL.Image.fromarray(image)
        image = self.model(image)
        image = image["depth"]
        image = np.array(image)
        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        return PIL.Image.fromarray(image)
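A minimal usage sketch for DepthEstimator (it downloads the default transformers depth-estimation checkpoint on first run; the input here is a placeholder array):

import numpy as np
from depth_estimator import DepthEstimator

estimator = DepthEstimator()
rgb = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8)  # stand-in for a real photo
depth = estimator(rgb, detect_resolution=512, image_resolution=512)
depth.save("depth_preview.png")  # PIL image, resized back to a 64-multiple resolution
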
environment.yml
ADDED
@@ -0,0 +1,7 @@
name: gradio
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.11
  - gradio
examples/bunny/frame_0001.png
ADDED
Git LFS Details

examples/bunny/mesh.obj
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7b6262e2b5563901d38599a08926ac57449b7b6c0c42a0c9b724154cde282799
size 6044863

examples/bunny/uv_normal.png
ADDED
Git LFS Details

examples/fighter/frame_0001.png
ADDED
Git LFS Details

examples/fighter/mesh.obj
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04c809af1ea9dadbea30261e0a8eef6b13735969e6b9e7d4e7423950072bc095
size 1576167

examples/fighter/uv_normal.png
ADDED
Git LFS Details

examples/highheel/frame_0001.png
ADDED
Git LFS Details

examples/highheel/mesh.obj
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1cb43ff640727c280221148c4047b9c3df2da6033b421a3cfa6d729848a128d7
size 8394487

examples/highheel/uv_normal.png
ADDED
Git LFS Details

examples/monkey/frame_0001.png
ADDED
Git LFS Details

examples/monkey/mesh.obj
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a49e7eef70eb55f7de5eab9615e981194db4cc0b1195bc8270d833aaa6047ac
size 6601492

examples/monkey/uv_normal.png
ADDED
Git LFS Details

examples/tank/frame_0001.png
ADDED
Git LFS Details

examples/tank/mesh.obj
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:301633de1a7757f78a6f67abb6e61bcc8e6a01f5a54a8582d1943ad0ad943211
size 6942253

examples/tank/uv_normal.png
ADDED
Git LFS Details

examples/tshirt/frame_0001.png
ADDED
Git LFS Details

examples/tshirt/mesh.obj
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7c6c9bdec8d646a1980e5b987a1182c92af84cc945ef49c1735d4337185d3e5
size 39275876

examples/tshirt/uv_normal.png
ADDED
Git LFS Details
image_segmentor.py
ADDED
@@ -0,0 +1,33 @@
import cv2
import numpy as np
import PIL.Image
import torch
from controlnet_aux.util import HWC3, ade_palette
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation

from cv_utils import resize_image


class ImageSegmentor:
    def __init__(self):
        self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")

    @torch.inference_mode()
    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        detect_resolution = kwargs.pop("detect_resolution", 512)
        image_resolution = kwargs.pop("image_resolution", 512)
        image = HWC3(image)
        image = resize_image(image, resolution=detect_resolution)
        image = PIL.Image.fromarray(image)

        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
        outputs = self.image_segmentor(pixel_values)
        seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(ade_palette()):
            color_seg[seg == label, :] = color
        color_seg = color_seg.astype(np.uint8)

        color_seg = resize_image(color_seg, resolution=image_resolution, interpolation=cv2.INTER_NEAREST)
        return PIL.Image.fromarray(color_seg)
install.sh
ADDED
@@ -0,0 +1,18 @@
#!/bin/bash
eval "$(conda shell.bash hook)"
# conda activate base
# conda remove -n matgen-plus --all

conda create -n matgen-plus python=3.11
conda activate matgen-plus

pip install diffusers["torch"] transformers accelerate xformers
pip install gradio
pip install controlnet-aux

# text2tex
conda install pytorch3d -c pytorch -c conda-forge
conda install -c conda-forge open-clip-torch pytorch-lightning
pip install trimesh xatlas scikit-learn opencv-python omegaconf

python app.py
model.py
ADDED
@@ -0,0 +1,959 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
|
3 |
+
# get socket and check if the name is vgldgx01
|
4 |
+
import socket
|
5 |
+
if socket.gethostname() != "vgldgx01":
|
6 |
+
import spaces #[uncomment to use ZeroGPU]
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import PIL.Image
|
10 |
+
import torch
|
11 |
+
from controlnet_aux.util import HWC3
|
12 |
+
from diffusers import (
|
13 |
+
ControlNetModel,
|
14 |
+
DiffusionPipeline,
|
15 |
+
StableDiffusionControlNetPipeline,
|
16 |
+
StableDiffusionImg2ImgPipeline,
|
17 |
+
UniPCMultistepScheduler,
|
18 |
+
DDIMScheduler, #rgb2x
|
19 |
+
)
|
20 |
+
import torchvision
|
21 |
+
from torchvision import transforms
|
22 |
+
from cv_utils import resize_image
|
23 |
+
from preprocessor import Preprocessor
|
24 |
+
from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
|
25 |
+
from tqdm.auto import tqdm
|
26 |
+
import subprocess
|
27 |
+
|
28 |
+
from rgb2x.pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
|
29 |
+
from app_texnet import image_to_temp_path
|
30 |
+
import os
|
31 |
+
import time
|
32 |
+
import tempfile
|
33 |
+
from text2tex.scripts.generate_texture import text2tex_call, init_args
|
34 |
+
from glob import glob
|
35 |
+
|
36 |
+
CONTROLNET_MODEL_IDS = {
|
37 |
+
# "Openpose": "lllyasviel/control_v11p_sd15_openpose",
|
38 |
+
# "Canny": "lllyasviel/control_v11p_sd15_canny",
|
39 |
+
# "MLSD": "lllyasviel/control_v11p_sd15_mlsd",
|
40 |
+
# "scribble": "lllyasviel/control_v11p_sd15_scribble",
|
41 |
+
# "softedge": "lllyasviel/control_v11p_sd15_softedge",
|
42 |
+
# "segmentation": "lllyasviel/control_v11p_sd15_seg",
|
43 |
+
# "depth": "lllyasviel/control_v11f1p_sd15_depth",
|
44 |
+
# "NormalBae": "lllyasviel/control_v11p_sd15_normalbae",
|
45 |
+
# "lineart": "lllyasviel/control_v11p_sd15_lineart",
|
46 |
+
# "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
|
47 |
+
# "shuffle": "lllyasviel/control_v11e_sd15_shuffle",
|
48 |
+
# "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
|
49 |
+
# "inpaint": "lllyasviel/control_v11e_sd15_inpaint",
|
50 |
+
# "texnet": "/home/jyang/projects/ObjectReal/logs/train_texnet_deploy/checkpoint-55000/controlnet" # load and call
|
51 |
+
"texnet": "jingyangcarl/texnet",
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
def download_all_controlnet_weights() -> None:
|
56 |
+
for model_id in CONTROLNET_MODEL_IDS.values():
|
57 |
+
ControlNetModel.from_pretrained(model_id)
|
58 |
+
|
59 |
+
|
60 |
+
class Model:
|
61 |
+
def __init__(
|
62 |
+
self, base_model_id: str = "stable-diffusion-v1-5/stable-diffusion-v1-5", task_name: str = "Canny"
|
63 |
+
) -> None:
|
64 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
65 |
+
self.base_model_id = ""
|
66 |
+
self.task_name = ""
|
67 |
+
self.pipe = self.load_pipe(base_model_id, task_name)
|
68 |
+
self.pipe_base = StableDiffusionImg2ImgPipeline.from_pretrained(
|
69 |
+
'runwayml/stable-diffusion-v1-5', safety_checker=None, torch_dtype=torch.float16
|
70 |
+
).to(self.device)
|
71 |
+
self.preprocessor = Preprocessor()
|
72 |
+
|
73 |
+
# set up pipe_rgb2x
|
74 |
+
self.pipe_rgb2x = StableDiffusionAOVMatEstPipeline.from_pretrained(
|
75 |
+
"zheng95z/rgb-to-x",
|
76 |
+
torch_dtype=torch.float16,
|
77 |
+
).to(self.device)
|
78 |
+
self.pipe_rgb2x.scheduler = DDIMScheduler.from_config(
|
79 |
+
self.pipe_rgb2x.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
80 |
+
)
|
81 |
+
self.pipe_rgb2x.set_progress_bar_config(disable=True)
|
82 |
+
|
83 |
+
# setup blender
|
84 |
+
self.blender_path = '/tmp/blender-3.2.2-linux-x64/blender'
|
85 |
+
if not os.path.exists(self.blender_path):
|
86 |
+
print("Downloading Blender...")
|
87 |
+
subprocess.run(["wget", "https://download.blender.org/release/Blender3.2/blender-3.2.2-linux-x64.tar.xz", "-O", "/tmp/blender-3.2.2-linux-x64.tar.xz"], check=True)
|
88 |
+
subprocess.run(["tar", "-xf", "/tmp/blender-3.2.2-linux-x64.tar.xz", "-C", "/tmp"], check=True)
|
89 |
+
print("Blender downloaded and extracted.")
|
90 |
+
|
91 |
+
def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
|
92 |
+
if (
|
93 |
+
base_model_id == self.base_model_id
|
94 |
+
and task_name == self.task_name
|
95 |
+
and hasattr(self, "pipe")
|
96 |
+
and self.pipe is not None
|
97 |
+
):
|
98 |
+
return self.pipe
|
99 |
+
model_id = CONTROLNET_MODEL_IDS[task_name]
|
100 |
+
controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
|
101 |
+
to_upload = False
|
102 |
+
if to_upload:
|
103 |
+
# confirm before uploading
|
104 |
+
confirm = input(f"Do you want to upload {model_id} to the hub? (y/n): ")
|
105 |
+
if confirm.lower() == "y":
|
106 |
+
controlnet.push_to_hub("jingyangcarl/texnet")
|
107 |
+
else:
|
108 |
+
print("Upload cancelled.")
|
109 |
+
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
110 |
+
base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
|
111 |
+
)
|
112 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
113 |
+
pipe.to(self.device)
|
114 |
+
if self.device.type == "cuda":
|
115 |
+
import os
|
116 |
+
if os.environ.get("SPACES_ZERO_GPU", "0") == "1":
|
117 |
+
# when running on ZeroGPU, enable CPU offload
|
118 |
+
# pipe.enable_xformers_memory_efficient_attention() doens't work
|
119 |
+
# pipe.enable_model_cpu_offload()
|
120 |
+
pass
|
121 |
+
else:
|
122 |
+
pipe.enable_xformers_memory_efficient_attention()
|
123 |
+
torch.cuda.empty_cache()
|
124 |
+
gc.collect()
|
125 |
+
self.base_model_id = base_model_id
|
126 |
+
self.task_name = task_name
|
127 |
+
return pipe
|
128 |
+
|
129 |
+
def set_base_model(self, base_model_id: str) -> str:
|
130 |
+
if not base_model_id or base_model_id == self.base_model_id:
|
131 |
+
return self.base_model_id
|
132 |
+
del self.pipe
|
133 |
+
torch.cuda.empty_cache()
|
134 |
+
gc.collect()
|
135 |
+
try:
|
136 |
+
self.pipe = self.load_pipe(base_model_id, self.task_name)
|
137 |
+
except Exception: # noqa: BLE001
|
138 |
+
self.pipe = self.load_pipe(self.base_model_id, self.task_name)
|
139 |
+
return self.base_model_id
|
140 |
+
|
141 |
+
def load_controlnet_weight(self, task_name: str) -> None:
|
142 |
+
if task_name == self.task_name:
|
143 |
+
return
|
144 |
+
if self.pipe is not None and hasattr(self.pipe, "controlnet"):
|
145 |
+
del self.pipe.controlnet
|
146 |
+
torch.cuda.empty_cache()
|
147 |
+
gc.collect()
|
148 |
+
model_id = CONTROLNET_MODEL_IDS[task_name]
|
149 |
+
controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
|
150 |
+
controlnet.to(self.device)
|
151 |
+
torch.cuda.empty_cache()
|
152 |
+
gc.collect()
|
153 |
+
self.pipe.controlnet = controlnet
|
154 |
+
self.task_name = task_name
|
155 |
+
|
156 |
+
def get_prompt(self, prompt: str, additional_prompt: str) -> str:
|
157 |
+
return additional_prompt if not prompt else f"{prompt}, {additional_prompt}"
|
158 |
+
|
159 |
+
# @spaces.GPU #[uncomment to use ZeroGPU]
|
160 |
+
@torch.autocast("cuda")
|
161 |
+
def run_pipe(
|
162 |
+
self,
|
163 |
+
prompt: str,
|
164 |
+
negative_prompt: str,
|
165 |
+
control_image: PIL.Image.Image,
|
166 |
+
num_images: int,
|
167 |
+
num_steps: int,
|
168 |
+
guidance_scale: float,
|
169 |
+
seed: int,
|
170 |
+
) -> list[PIL.Image.Image]:
|
171 |
+
generator = torch.Generator().manual_seed(seed)
|
172 |
+
# self.pipe.to(self.device)
|
173 |
+
return self.pipe(
|
174 |
+
prompt=prompt,
|
175 |
+
negative_prompt=negative_prompt,
|
176 |
+
guidance_scale=guidance_scale,
|
177 |
+
num_images_per_prompt=num_images,
|
178 |
+
num_inference_steps=num_steps,
|
179 |
+
generator=generator,
|
180 |
+
image=control_image,
|
181 |
+
).images
|
182 |
+
|
183 |
+
# @spaces.GPU #[uncomment to use ZeroGPU]
|
184 |
+
@torch.inference_mode()
|
185 |
+
def process_texnet(
|
186 |
+
self,
|
187 |
+
obj_name: str,
|
188 |
+
represented_image: np.ndarray | None, # not used
|
189 |
+
image: np.ndarray,
|
190 |
+
prompt: str,
|
191 |
+
additional_prompt: str,
|
192 |
+
negative_prompt: str,
|
193 |
+
num_images: int,
|
194 |
+
image_resolution: int,
|
195 |
+
num_steps: int,
|
196 |
+
guidance_scale: float,
|
197 |
+
seed: int,
|
198 |
+
low_threshold: int,
|
199 |
+
high_threshold: int,
|
200 |
+
) -> list[PIL.Image.Image]:
|
201 |
+
if image is None:
|
202 |
+
raise ValueError
|
203 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
204 |
+
raise ValueError
|
205 |
+
if num_images > MAX_NUM_IMAGES:
|
206 |
+
raise ValueError
|
207 |
+
|
208 |
+
prompt_nospace = prompt.replace(' ', '_')
|
209 |
+
|
210 |
+
# self.preprocessor.load("texnet")
|
211 |
+
# control_image = self.preprocessor(
|
212 |
+
# image=image, low_threshold=low_threshold, high_threshold=high_threshold, image_resolution=image_resolution, output_type="pil"
|
213 |
+
# )
|
214 |
+
|
215 |
+
# self.load_controlnet_weight("texnet")
|
216 |
+
# tex_coarse = self.run_pipe(
|
217 |
+
# prompt=self.get_prompt(prompt, additional_prompt),
|
218 |
+
# negative_prompt=negative_prompt,
|
219 |
+
# control_image=control_image,
|
220 |
+
# num_images=num_images,
|
221 |
+
# num_steps=num_steps,
|
222 |
+
# guidance_scale=guidance_scale,
|
223 |
+
# seed=seed,
|
224 |
+
# )
|
225 |
+
|
226 |
+
# # use img2img pipeline
|
227 |
+
# self.pipe_backup = self.pipe
|
228 |
+
# self.pipe = self.pipe_base
|
229 |
+
|
230 |
+
# # refine
|
231 |
+
tex_fine = []
|
232 |
+
mesh_fine = []
|
233 |
+
# for result_coarse in tex_coarse:
|
234 |
+
# # clean up GPU cache
|
235 |
+
# torch.cuda.empty_cache()
|
236 |
+
# gc.collect()
|
237 |
+
|
238 |
+
# # masking
|
239 |
+
# mask = (np.array(control_image).sum(axis=-1) == 0)[...,None]
|
240 |
+
# image_masked = PIL.Image.fromarray(np.where(mask, control_image, result_coarse))
|
241 |
+
# image_blurry = transforms.GaussianBlur(kernel_size=5, sigma=1)(image_masked)
|
242 |
+
# result_fine = self.run_pipe(
|
243 |
+
# # prompt=prompt,
|
244 |
+
# prompt=self.get_prompt(prompt, additional_prompt),
|
245 |
+
# negative_prompt=negative_prompt,
|
246 |
+
# control_image=image_blurry,
|
247 |
+
# num_images=1,
|
248 |
+
# num_steps=num_steps,
|
249 |
+
# guidance_scale=guidance_scale,
|
250 |
+
# seed=seed,
|
251 |
+
# )[0]
|
252 |
+
# result_fine = PIL.Image.fromarray(np.where(mask, control_image, result_fine))
|
253 |
+
# tex_fine.append(result_fine)
|
254 |
+
|
255 |
+
temp_out_path = tempfile.mkdtemp()
|
256 |
+
temp_out_path = 'output'
|
257 |
+
|
258 |
+
# put text2tex here,
|
259 |
+
args = init_args()
|
260 |
+
args.input_dir = f'examples/{obj_name}/'
|
261 |
+
args.output_dir = os.path.join(temp_out_path, f'{obj_name}/{prompt_nospace}')
|
262 |
+
args.obj_name = obj_name
|
263 |
+
args.obj_file = 'mesh.obj'
|
264 |
+
args.prompt = f'{prompt} {obj_name}'
|
265 |
+
args.add_view_to_prompt = True
|
266 |
+
args.ddim_steps = 5
|
267 |
+
# args.ddim_steps = 50
|
268 |
+
args.new_strength = 1.0
|
269 |
+
args.update_strength = 0.3
|
270 |
+
args.view_threshold = 0.1
|
271 |
+
args.blend = 0
|
272 |
+
args.dist = 1
|
273 |
+
args.num_viewpoints = 2
|
274 |
+
# args.num_viewpoints = 36
|
275 |
+
args.viewpoint_mode = 'predefined'
|
276 |
+
args.use_principle = True
|
277 |
+
args.update_steps = 2
|
278 |
+
# args.update_steps = 20
|
279 |
+
args.update_mode = 'heuristic'
|
280 |
+
args.seed = 42
|
281 |
+
args.post_process = True
|
282 |
+
args.device = '2080'
|
283 |
+
args.uv_size = 1000
|
284 |
+
args.image_size = 512
|
285 |
+
# args.image_size = 768
|
286 |
+
args.use_objaverse = True # assume the mesh is normalized with y-axis as up
|
287 |
+
output_dir = text2tex_call(args)
|
288 |
+
|
289 |
+
# get the texture and mesh with underscore '_post', which is the id of the last mesh, should be good for the visual
|
290 |
+
post_idx = glob(os.path.join(output_dir, 'update', 'mesh', "*_post.png"))[0].split('/')[-1].split('_')[0]
|
291 |
+
|
292 |
+
tex_fine.append(PIL.Image.open(os.path.join(output_dir, 'update', 'mesh', f"{post_idx}.png")).convert("RGB"))
|
293 |
+
mesh_fine.append(os.path.join(output_dir, 'update', 'mesh', f"{post_idx}.obj"))
|
294 |
+
torch.cuda.empty_cache()
|
295 |
+
|
296 |
+
# restore the original pipe
|
297 |
+
# self.pipe = self.pipe_backup
|
298 |
+
|
299 |
+
# for now, use rgb2x to derive the intrinsic maps (albedo, normal, roughness, metallic, irradiance) from the generated texture
|
300 |
+
def rgb2x(
|
301 |
+
pipeline,
|
302 |
+
photo,
|
303 |
+
inference_step = 50,
|
304 |
+
num_samples = 1,
|
305 |
+
):
|
306 |
+
generator = torch.Generator(device="cuda").manual_seed(seed)
|
307 |
+
|
308 |
+
# Resize so that the longer side is at most 1000 px and both dimensions are multiples of 8 (required by the diffusion VAE)
|
309 |
+
old_height = photo.shape[1]
|
310 |
+
old_width = photo.shape[2]
|
311 |
+
new_height = old_height
|
312 |
+
new_width = old_width
|
313 |
+
radio = old_height / old_width
|
314 |
+
max_side = 1000
|
315 |
+
if old_height > old_width:
|
316 |
+
new_height = max_side
|
317 |
+
new_width = int(new_height / radio)
|
318 |
+
else:
|
319 |
+
new_width = max_side
|
320 |
+
new_height = int(new_width * radio)
|
321 |
+
|
322 |
+
if new_width % 8 != 0 or new_height % 8 != 0:
|
323 |
+
new_width = new_width // 8 * 8
|
324 |
+
new_height = new_height // 8 * 8
|
325 |
+
|
326 |
+
photo = torchvision.transforms.Resize((new_height, new_width))(photo)
|
327 |
+
|
328 |
+
required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
|
329 |
+
prompts = {
|
330 |
+
"albedo": "Albedo (diffuse basecolor)",
|
331 |
+
"normal": "Camera-space Normal",
|
332 |
+
"roughness": "Roughness",
|
333 |
+
"metallic": "Metallicness",
|
334 |
+
"irradiance": "Irradiance (diffuse lighting)",
|
335 |
+
}
|
336 |
+
|
337 |
+
return_list = []
|
338 |
+
for i in tqdm(range(num_samples), desc="Running Pipeline", leave=False):
|
339 |
+
for aov_name in required_aovs:
|
340 |
+
prompt = prompts[aov_name]
|
341 |
+
generated_image = pipeline(
|
342 |
+
prompt=prompt,
|
343 |
+
photo=photo,
|
344 |
+
num_inference_steps=inference_step,
|
345 |
+
height=new_height,
|
346 |
+
width=new_width,
|
347 |
+
generator=generator,
|
348 |
+
required_aovs=[aov_name],
|
349 |
+
).images[0][0]
|
350 |
+
|
351 |
+
generated_image = torchvision.transforms.Resize(
|
352 |
+
(old_height, old_width)
|
353 |
+
)(generated_image)
|
354 |
+
|
355 |
+
# generated_image = (generated_image, f"Generated {aov_name} {i}")
|
356 |
+
# generated_image = (generated_image, f"{aov_name}")
|
357 |
+
return_list.append(generated_image)
|
358 |
+
|
359 |
+
return photo, return_list, prompts
|
360 |
+
|
361 |
+
# run the rgb2x pipeline on the refined texture to estimate its intrinsic channels
|
362 |
+
_, preds, prompts = rgb2x(self.pipe_rgb2x, torchvision.transforms.PILToTensor()(tex_fine[0]).to(self.pipe.device), inference_step=num_steps, num_samples=num_images)
|
363 |
+
|
364 |
+
intrinsic_dir = os.path.join(output_dir, 'intrinsic')
|
365 |
+
use_text2tex = True
|
366 |
+
if use_text2tex:
|
367 |
+
base_color_path = image_to_temp_path(tex_fine[0], "base_color", out_dir=intrinsic_dir)
|
368 |
+
normal_map_path = image_to_temp_path(preds[0], "normal_map", out_dir=intrinsic_dir)
|
369 |
+
roughness_path = image_to_temp_path(preds[1], "roughness", out_dir=intrinsic_dir)
|
370 |
+
metallic_path = image_to_temp_path(preds[2], "metallic", out_dir=intrinsic_dir)
|
371 |
+
else:
|
372 |
+
base_color_path = image_to_temp_path(tex_fine[0].rotate(90), "base_color", out_dir=intrinsic_dir)
|
373 |
+
normal_map_path = image_to_temp_path(preds[0].rotate(90), "normal_map", out_dir=intrinsic_dir)
|
374 |
+
roughness_path = image_to_temp_path(preds[1].rotate(90), "roughness", out_dir=intrinsic_dir)
|
375 |
+
metallic_path = image_to_temp_path(preds[2].rotate(90), "metallic", out_dir=intrinsic_dir)
|
376 |
+
current_timecode = time.strftime("%Y%m%d_%H%M%S")
|
377 |
+
# output_blend_path = os.path.join(os.getcwd(), "output", f"{obj_name}_{prompt_nospace}_{current_timecode}.blend") # replace with desired output path
|
378 |
+
output_blend_path = os.path.join(tempfile.mkdtemp(), f"{obj_name}_{prompt_nospace}_{current_timecode}.blend") # replace with desired output path
|
379 |
+
os.makedirs(os.path.dirname(output_blend_path), exist_ok=True)
|
380 |
+
|
381 |
+
def run_blend_generation(
|
382 |
+
blender_path,
|
383 |
+
generate_script_path,
|
384 |
+
obj_path,
|
385 |
+
base_color_path,
|
386 |
+
normal_map_path,
|
387 |
+
roughness_path,
|
388 |
+
metallic_path,
|
389 |
+
output_blend
|
390 |
+
):
|
391 |
+
cmd = [
|
392 |
+
blender_path, "--background", "--python", generate_script_path, "--",
|
393 |
+
obj_path, base_color_path, normal_map_path, roughness_path, metallic_path, output_blend
|
394 |
+
]
|
395 |
+
subprocess.run(cmd, check=True)
|
396 |
+
|
397 |
+
# check whether self.blender_path exists; if it does not, download Blender first
|
398 |
+
run_blend_generation(
|
399 |
+
blender_path=self.blender_path,
|
400 |
+
generate_script_path="rgb2x/generate_blend.py",
|
401 |
+
# obj_path=f"examples/{obj_name}/mesh.obj", # replace with actual mesh path
|
402 |
+
obj_path=mesh_fine[0], # replace with actual mesh path
|
403 |
+
base_color_path=base_color_path,
|
404 |
+
normal_map_path=normal_map_path,
|
405 |
+
roughness_path=roughness_path,
|
406 |
+
metallic_path=metallic_path,
|
407 |
+
output_blend=output_blend_path # replace with desired output path
|
408 |
+
)
|
409 |
+
|
410 |
+
# gallery outputs
|
411 |
+
return [*tex_fine], [preds[1]], [preds[2]], [preds[3]], [output_blend_path]
|
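For clarity, the subprocess command assembled by run_blend_generation() above is equivalent to the following standalone invocation. This is only a sketch: every path is a placeholder, and the Blender binary location is an assumption (self.blender_path is set elsewhere in this class).

# Sketch of the Blender call built by run_blend_generation(); all paths below are
# placeholders, not values taken from the repo.
import subprocess

cmd = [
    "/opt/blender/blender",            # assumed location of the Blender binary (self.blender_path)
    "--background",
    "--python", "rgb2x/generate_blend.py",
    "--",
    "output/monkey/mesh_post.obj",     # hypothetical mesh exported by text2tex
    "intrinsic/base_color.png",        # albedo / base color texture
    "intrinsic/normal_map.png",        # normal map from rgb2x
    "intrinsic/roughness.png",         # roughness map from rgb2x
    "intrinsic/metallic.png",          # metallic map from rgb2x
    "output/monkey_textured.blend",    # packed .blend returned to the gallery
]
subprocess.run(cmd, check=True)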
412 |
+
|
413 |
+
# @spaces.GPU #[uncomment to use ZeroGPU]
|
414 |
+
@torch.inference_mode()
|
415 |
+
def process_canny(
|
416 |
+
self,
|
417 |
+
image: np.ndarray,
|
418 |
+
prompt: str,
|
419 |
+
additional_prompt: str,
|
420 |
+
negative_prompt: str,
|
421 |
+
num_images: int,
|
422 |
+
image_resolution: int,
|
423 |
+
num_steps: int,
|
424 |
+
guidance_scale: float,
|
425 |
+
seed: int,
|
426 |
+
low_threshold: int,
|
427 |
+
high_threshold: int,
|
428 |
+
) -> list[PIL.Image.Image]:
|
429 |
+
if image is None:
|
430 |
+
raise ValueError
|
431 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
432 |
+
raise ValueError
|
433 |
+
if num_images > MAX_NUM_IMAGES:
|
434 |
+
raise ValueError
|
435 |
+
|
436 |
+
self.preprocessor.load("Canny")
|
437 |
+
control_image = self.preprocessor(
|
438 |
+
image=image, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=image_resolution
|
439 |
+
)
|
440 |
+
|
441 |
+
self.load_controlnet_weight("Canny")
|
442 |
+
results = self.run_pipe(
|
443 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
444 |
+
negative_prompt=negative_prompt,
|
445 |
+
control_image=control_image,
|
446 |
+
num_images=num_images,
|
447 |
+
num_steps=num_steps,
|
448 |
+
guidance_scale=guidance_scale,
|
449 |
+
seed=seed,
|
450 |
+
)
|
451 |
+
return [control_image, *results]
|
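As a usage note, process_canny (and the sibling process_* methods below) can also be driven directly from a script rather than through the Gradio UI. A minimal sketch follows; the class name Model and its no-argument constructor are assumptions, since the actual class definition appears earlier in model.py.

# Sketch: call process_canny outside the UI. The class name `Model` and the
# no-argument constructor are assumptions; adjust to the actual definition in model.py.
import numpy as np
import PIL.Image
from model import Model

model = Model()
image = np.array(PIL.Image.open("examples/monkey/frame_0001.png").convert("RGB"))
outputs = model.process_canny(
    image=image,
    prompt="a bronze sculpture",
    additional_prompt="best quality, extremely detailed",
    negative_prompt="lowres, bad quality",
    num_images=1,
    image_resolution=512,
    num_steps=20,
    guidance_scale=9.0,
    seed=0,
    low_threshold=100,
    high_threshold=200,
)
# outputs[0] is the Canny control image; the remaining entries are generated samples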
452 |
+
|
453 |
+
@torch.inference_mode()
|
454 |
+
def process_mlsd(
|
455 |
+
self,
|
456 |
+
image: np.ndarray,
|
457 |
+
prompt: str,
|
458 |
+
additional_prompt: str,
|
459 |
+
negative_prompt: str,
|
460 |
+
num_images: int,
|
461 |
+
image_resolution: int,
|
462 |
+
preprocess_resolution: int,
|
463 |
+
num_steps: int,
|
464 |
+
guidance_scale: float,
|
465 |
+
seed: int,
|
466 |
+
value_threshold: float,
|
467 |
+
distance_threshold: float,
|
468 |
+
) -> list[PIL.Image.Image]:
|
469 |
+
if image is None:
|
470 |
+
raise ValueError
|
471 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
472 |
+
raise ValueError
|
473 |
+
if num_images > MAX_NUM_IMAGES:
|
474 |
+
raise ValueError
|
475 |
+
|
476 |
+
self.preprocessor.load("MLSD")
|
477 |
+
control_image = self.preprocessor(
|
478 |
+
image=image,
|
479 |
+
image_resolution=image_resolution,
|
480 |
+
detect_resolution=preprocess_resolution,
|
481 |
+
thr_v=value_threshold,
|
482 |
+
thr_d=distance_threshold,
|
483 |
+
)
|
484 |
+
self.load_controlnet_weight("MLSD")
|
485 |
+
results = self.run_pipe(
|
486 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
487 |
+
negative_prompt=negative_prompt,
|
488 |
+
control_image=control_image,
|
489 |
+
num_images=num_images,
|
490 |
+
num_steps=num_steps,
|
491 |
+
guidance_scale=guidance_scale,
|
492 |
+
seed=seed,
|
493 |
+
)
|
494 |
+
return [control_image, *results]
|
495 |
+
|
496 |
+
@torch.inference_mode()
|
497 |
+
def process_scribble(
|
498 |
+
self,
|
499 |
+
image: np.ndarray,
|
500 |
+
prompt: str,
|
501 |
+
additional_prompt: str,
|
502 |
+
negative_prompt: str,
|
503 |
+
num_images: int,
|
504 |
+
image_resolution: int,
|
505 |
+
preprocess_resolution: int,
|
506 |
+
num_steps: int,
|
507 |
+
guidance_scale: float,
|
508 |
+
seed: int,
|
509 |
+
preprocessor_name: str,
|
510 |
+
) -> list[PIL.Image.Image]:
|
511 |
+
if image is None:
|
512 |
+
raise ValueError
|
513 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
514 |
+
raise ValueError
|
515 |
+
if num_images > MAX_NUM_IMAGES:
|
516 |
+
raise ValueError
|
517 |
+
|
518 |
+
if preprocessor_name == "None":
|
519 |
+
image = HWC3(image)
|
520 |
+
image = resize_image(image, resolution=image_resolution)
|
521 |
+
control_image = PIL.Image.fromarray(image)
|
522 |
+
elif preprocessor_name == "HED":
|
523 |
+
self.preprocessor.load(preprocessor_name)
|
524 |
+
control_image = self.preprocessor(
|
525 |
+
image=image,
|
526 |
+
image_resolution=image_resolution,
|
527 |
+
detect_resolution=preprocess_resolution,
|
528 |
+
scribble=False,
|
529 |
+
)
|
530 |
+
elif preprocessor_name == "PidiNet":
|
531 |
+
self.preprocessor.load(preprocessor_name)
|
532 |
+
control_image = self.preprocessor(
|
533 |
+
image=image,
|
534 |
+
image_resolution=image_resolution,
|
535 |
+
detect_resolution=preprocess_resolution,
|
536 |
+
safe=False,
|
537 |
+
)
|
538 |
+
self.load_controlnet_weight("scribble")
|
539 |
+
results = self.run_pipe(
|
540 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
541 |
+
negative_prompt=negative_prompt,
|
542 |
+
control_image=control_image,
|
543 |
+
num_images=num_images,
|
544 |
+
num_steps=num_steps,
|
545 |
+
guidance_scale=guidance_scale,
|
546 |
+
seed=seed,
|
547 |
+
)
|
548 |
+
return [control_image, *results]
|
549 |
+
|
550 |
+
@torch.inference_mode()
|
551 |
+
def process_scribble_interactive(
|
552 |
+
self,
|
553 |
+
image_and_mask: dict[str, np.ndarray | list[np.ndarray]] | None,
|
554 |
+
prompt: str,
|
555 |
+
additional_prompt: str,
|
556 |
+
negative_prompt: str,
|
557 |
+
num_images: int,
|
558 |
+
image_resolution: int,
|
559 |
+
num_steps: int,
|
560 |
+
guidance_scale: float,
|
561 |
+
seed: int,
|
562 |
+
) -> list[PIL.Image.Image]:
|
563 |
+
if image_and_mask is None:
|
564 |
+
raise ValueError
|
565 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
566 |
+
raise ValueError
|
567 |
+
if num_images > MAX_NUM_IMAGES:
|
568 |
+
raise ValueError
|
569 |
+
|
570 |
+
image = 255 - image_and_mask["composite"] # type: ignore
|
571 |
+
image = HWC3(image)
|
572 |
+
image = resize_image(image, resolution=image_resolution)
|
573 |
+
control_image = PIL.Image.fromarray(image)
|
574 |
+
|
575 |
+
self.load_controlnet_weight("scribble")
|
576 |
+
results = self.run_pipe(
|
577 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
578 |
+
negative_prompt=negative_prompt,
|
579 |
+
control_image=control_image,
|
580 |
+
num_images=num_images,
|
581 |
+
num_steps=num_steps,
|
582 |
+
guidance_scale=guidance_scale,
|
583 |
+
seed=seed,
|
584 |
+
)
|
585 |
+
return [control_image, *results]
|
586 |
+
|
587 |
+
@torch.inference_mode()
|
588 |
+
def process_softedge(
|
589 |
+
self,
|
590 |
+
image: np.ndarray,
|
591 |
+
prompt: str,
|
592 |
+
additional_prompt: str,
|
593 |
+
negative_prompt: str,
|
594 |
+
num_images: int,
|
595 |
+
image_resolution: int,
|
596 |
+
preprocess_resolution: int,
|
597 |
+
num_steps: int,
|
598 |
+
guidance_scale: float,
|
599 |
+
seed: int,
|
600 |
+
preprocessor_name: str,
|
601 |
+
) -> list[PIL.Image.Image]:
|
602 |
+
if image is None:
|
603 |
+
raise ValueError
|
604 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
605 |
+
raise ValueError
|
606 |
+
if num_images > MAX_NUM_IMAGES:
|
607 |
+
raise ValueError
|
608 |
+
|
609 |
+
if preprocessor_name == "None":
|
610 |
+
image = HWC3(image)
|
611 |
+
image = resize_image(image, resolution=image_resolution)
|
612 |
+
control_image = PIL.Image.fromarray(image)
|
613 |
+
elif preprocessor_name in ["HED", "HED safe"]:
|
614 |
+
safe = "safe" in preprocessor_name
|
615 |
+
self.preprocessor.load("HED")
|
616 |
+
control_image = self.preprocessor(
|
617 |
+
image=image,
|
618 |
+
image_resolution=image_resolution,
|
619 |
+
detect_resolution=preprocess_resolution,
|
620 |
+
scribble=safe,
|
621 |
+
)
|
622 |
+
elif preprocessor_name in ["PidiNet", "PidiNet safe"]:
|
623 |
+
safe = "safe" in preprocessor_name
|
624 |
+
self.preprocessor.load("PidiNet")
|
625 |
+
control_image = self.preprocessor(
|
626 |
+
image=image,
|
627 |
+
image_resolution=image_resolution,
|
628 |
+
detect_resolution=preprocess_resolution,
|
629 |
+
safe=safe,
|
630 |
+
)
|
631 |
+
else:
|
632 |
+
raise ValueError
|
633 |
+
self.load_controlnet_weight("softedge")
|
634 |
+
results = self.run_pipe(
|
635 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
636 |
+
negative_prompt=negative_prompt,
|
637 |
+
control_image=control_image,
|
638 |
+
num_images=num_images,
|
639 |
+
num_steps=num_steps,
|
640 |
+
guidance_scale=guidance_scale,
|
641 |
+
seed=seed,
|
642 |
+
)
|
643 |
+
return [control_image, *results]
|
644 |
+
|
645 |
+
@torch.inference_mode()
|
646 |
+
def process_openpose(
|
647 |
+
self,
|
648 |
+
image: np.ndarray,
|
649 |
+
prompt: str,
|
650 |
+
additional_prompt: str,
|
651 |
+
negative_prompt: str,
|
652 |
+
num_images: int,
|
653 |
+
image_resolution: int,
|
654 |
+
preprocess_resolution: int,
|
655 |
+
num_steps: int,
|
656 |
+
guidance_scale: float,
|
657 |
+
seed: int,
|
658 |
+
preprocessor_name: str,
|
659 |
+
) -> list[PIL.Image.Image]:
|
660 |
+
if image is None:
|
661 |
+
raise ValueError
|
662 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
663 |
+
raise ValueError
|
664 |
+
if num_images > MAX_NUM_IMAGES:
|
665 |
+
raise ValueError
|
666 |
+
|
667 |
+
if preprocessor_name == "None":
|
668 |
+
image = HWC3(image)
|
669 |
+
image = resize_image(image, resolution=image_resolution)
|
670 |
+
control_image = PIL.Image.fromarray(image)
|
671 |
+
else:
|
672 |
+
self.preprocessor.load("Openpose")
|
673 |
+
control_image = self.preprocessor(
|
674 |
+
image=image,
|
675 |
+
image_resolution=image_resolution,
|
676 |
+
detect_resolution=preprocess_resolution,
|
677 |
+
hand_and_face=True,
|
678 |
+
)
|
679 |
+
self.load_controlnet_weight("Openpose")
|
680 |
+
results = self.run_pipe(
|
681 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
682 |
+
negative_prompt=negative_prompt,
|
683 |
+
control_image=control_image,
|
684 |
+
num_images=num_images,
|
685 |
+
num_steps=num_steps,
|
686 |
+
guidance_scale=guidance_scale,
|
687 |
+
seed=seed,
|
688 |
+
)
|
689 |
+
return [control_image, *results]
|
690 |
+
|
691 |
+
@torch.inference_mode()
|
692 |
+
def process_segmentation(
|
693 |
+
self,
|
694 |
+
image: np.ndarray,
|
695 |
+
prompt: str,
|
696 |
+
additional_prompt: str,
|
697 |
+
negative_prompt: str,
|
698 |
+
num_images: int,
|
699 |
+
image_resolution: int,
|
700 |
+
preprocess_resolution: int,
|
701 |
+
num_steps: int,
|
702 |
+
guidance_scale: float,
|
703 |
+
seed: int,
|
704 |
+
preprocessor_name: str,
|
705 |
+
) -> list[PIL.Image.Image]:
|
706 |
+
if image is None:
|
707 |
+
raise ValueError
|
708 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
709 |
+
raise ValueError
|
710 |
+
if num_images > MAX_NUM_IMAGES:
|
711 |
+
raise ValueError
|
712 |
+
|
713 |
+
if preprocessor_name == "None":
|
714 |
+
image = HWC3(image)
|
715 |
+
image = resize_image(image, resolution=image_resolution)
|
716 |
+
control_image = PIL.Image.fromarray(image)
|
717 |
+
else:
|
718 |
+
self.preprocessor.load(preprocessor_name)
|
719 |
+
control_image = self.preprocessor(
|
720 |
+
image=image,
|
721 |
+
image_resolution=image_resolution,
|
722 |
+
detect_resolution=preprocess_resolution,
|
723 |
+
)
|
724 |
+
self.load_controlnet_weight("segmentation")
|
725 |
+
results = self.run_pipe(
|
726 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
727 |
+
negative_prompt=negative_prompt,
|
728 |
+
control_image=control_image,
|
729 |
+
num_images=num_images,
|
730 |
+
num_steps=num_steps,
|
731 |
+
guidance_scale=guidance_scale,
|
732 |
+
seed=seed,
|
733 |
+
)
|
734 |
+
return [control_image, *results]
|
735 |
+
|
736 |
+
@torch.inference_mode()
|
737 |
+
def process_depth(
|
738 |
+
self,
|
739 |
+
image: np.ndarray,
|
740 |
+
prompt: str,
|
741 |
+
additional_prompt: str,
|
742 |
+
negative_prompt: str,
|
743 |
+
num_images: int,
|
744 |
+
image_resolution: int,
|
745 |
+
preprocess_resolution: int,
|
746 |
+
num_steps: int,
|
747 |
+
guidance_scale: float,
|
748 |
+
seed: int,
|
749 |
+
preprocessor_name: str,
|
750 |
+
) -> list[PIL.Image.Image]:
|
751 |
+
if image is None:
|
752 |
+
raise ValueError
|
753 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
754 |
+
raise ValueError
|
755 |
+
if num_images > MAX_NUM_IMAGES:
|
756 |
+
raise ValueError
|
757 |
+
|
758 |
+
if preprocessor_name == "None":
|
759 |
+
image = HWC3(image)
|
760 |
+
image = resize_image(image, resolution=image_resolution)
|
761 |
+
control_image = PIL.Image.fromarray(image)
|
762 |
+
else:
|
763 |
+
self.preprocessor.load(preprocessor_name)
|
764 |
+
control_image = self.preprocessor(
|
765 |
+
image=image,
|
766 |
+
image_resolution=image_resolution,
|
767 |
+
detect_resolution=preprocess_resolution,
|
768 |
+
)
|
769 |
+
self.load_controlnet_weight("depth")
|
770 |
+
results = self.run_pipe(
|
771 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
772 |
+
negative_prompt=negative_prompt,
|
773 |
+
control_image=control_image,
|
774 |
+
num_images=num_images,
|
775 |
+
num_steps=num_steps,
|
776 |
+
guidance_scale=guidance_scale,
|
777 |
+
seed=seed,
|
778 |
+
)
|
779 |
+
return [control_image, *results]
|
780 |
+
|
781 |
+
@torch.inference_mode()
|
782 |
+
def process_normal(
|
783 |
+
self,
|
784 |
+
image: np.ndarray,
|
785 |
+
prompt: str,
|
786 |
+
additional_prompt: str,
|
787 |
+
negative_prompt: str,
|
788 |
+
num_images: int,
|
789 |
+
image_resolution: int,
|
790 |
+
preprocess_resolution: int,
|
791 |
+
num_steps: int,
|
792 |
+
guidance_scale: float,
|
793 |
+
seed: int,
|
794 |
+
preprocessor_name: str,
|
795 |
+
) -> list[PIL.Image.Image]:
|
796 |
+
if image is None:
|
797 |
+
raise ValueError
|
798 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
799 |
+
raise ValueError
|
800 |
+
if num_images > MAX_NUM_IMAGES:
|
801 |
+
raise ValueError
|
802 |
+
|
803 |
+
if preprocessor_name == "None":
|
804 |
+
image = HWC3(image)
|
805 |
+
image = resize_image(image, resolution=image_resolution)
|
806 |
+
control_image = PIL.Image.fromarray(image)
|
807 |
+
else:
|
808 |
+
self.preprocessor.load("NormalBae")
|
809 |
+
control_image = self.preprocessor(
|
810 |
+
image=image,
|
811 |
+
image_resolution=image_resolution,
|
812 |
+
detect_resolution=preprocess_resolution,
|
813 |
+
)
|
814 |
+
self.load_controlnet_weight("NormalBae")
|
815 |
+
results = self.run_pipe(
|
816 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
817 |
+
negative_prompt=negative_prompt,
|
818 |
+
control_image=control_image,
|
819 |
+
num_images=num_images,
|
820 |
+
num_steps=num_steps,
|
821 |
+
guidance_scale=guidance_scale,
|
822 |
+
seed=seed,
|
823 |
+
)
|
824 |
+
return [control_image, *results]
|
825 |
+
|
826 |
+
@torch.inference_mode()
|
827 |
+
def process_lineart(
|
828 |
+
self,
|
829 |
+
image: np.ndarray,
|
830 |
+
prompt: str,
|
831 |
+
additional_prompt: str,
|
832 |
+
negative_prompt: str,
|
833 |
+
num_images: int,
|
834 |
+
image_resolution: int,
|
835 |
+
preprocess_resolution: int,
|
836 |
+
num_steps: int,
|
837 |
+
guidance_scale: float,
|
838 |
+
seed: int,
|
839 |
+
preprocessor_name: str,
|
840 |
+
) -> list[PIL.Image.Image]:
|
841 |
+
if image is None:
|
842 |
+
raise ValueError
|
843 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
844 |
+
raise ValueError
|
845 |
+
if num_images > MAX_NUM_IMAGES:
|
846 |
+
raise ValueError
|
847 |
+
|
848 |
+
if preprocessor_name in ["None", "None (anime)"]:
|
849 |
+
image = HWC3(image)
|
850 |
+
image = resize_image(image, resolution=image_resolution)
|
851 |
+
control_image = PIL.Image.fromarray(image)
|
852 |
+
elif preprocessor_name in ["Lineart", "Lineart coarse"]:
|
853 |
+
coarse = "coarse" in preprocessor_name
|
854 |
+
self.preprocessor.load("Lineart")
|
855 |
+
control_image = self.preprocessor(
|
856 |
+
image=image,
|
857 |
+
image_resolution=image_resolution,
|
858 |
+
detect_resolution=preprocess_resolution,
|
859 |
+
coarse=coarse,
|
860 |
+
)
|
861 |
+
elif preprocessor_name == "Lineart (anime)":
|
862 |
+
self.preprocessor.load("LineartAnime")
|
863 |
+
control_image = self.preprocessor(
|
864 |
+
image=image,
|
865 |
+
image_resolution=image_resolution,
|
866 |
+
detect_resolution=preprocess_resolution,
|
867 |
+
)
|
868 |
+
if "anime" in preprocessor_name:
|
869 |
+
self.load_controlnet_weight("lineart_anime")
|
870 |
+
else:
|
871 |
+
self.load_controlnet_weight("lineart")
|
872 |
+
results = self.run_pipe(
|
873 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
874 |
+
negative_prompt=negative_prompt,
|
875 |
+
control_image=control_image,
|
876 |
+
num_images=num_images,
|
877 |
+
num_steps=num_steps,
|
878 |
+
guidance_scale=guidance_scale,
|
879 |
+
seed=seed,
|
880 |
+
)
|
881 |
+
return [control_image, *results]
|
882 |
+
|
883 |
+
@torch.inference_mode()
|
884 |
+
def process_shuffle(
|
885 |
+
self,
|
886 |
+
image: np.ndarray,
|
887 |
+
prompt: str,
|
888 |
+
additional_prompt: str,
|
889 |
+
negative_prompt: str,
|
890 |
+
num_images: int,
|
891 |
+
image_resolution: int,
|
892 |
+
num_steps: int,
|
893 |
+
guidance_scale: float,
|
894 |
+
seed: int,
|
895 |
+
preprocessor_name: str,
|
896 |
+
) -> list[PIL.Image.Image]:
|
897 |
+
if image is None:
|
898 |
+
raise ValueError
|
899 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
900 |
+
raise ValueError
|
901 |
+
if num_images > MAX_NUM_IMAGES:
|
902 |
+
raise ValueError
|
903 |
+
|
904 |
+
if preprocessor_name == "None":
|
905 |
+
image = HWC3(image)
|
906 |
+
image = resize_image(image, resolution=image_resolution)
|
907 |
+
control_image = PIL.Image.fromarray(image)
|
908 |
+
else:
|
909 |
+
self.preprocessor.load(preprocessor_name)
|
910 |
+
control_image = self.preprocessor(
|
911 |
+
image=image,
|
912 |
+
image_resolution=image_resolution,
|
913 |
+
)
|
914 |
+
self.load_controlnet_weight("shuffle")
|
915 |
+
results = self.run_pipe(
|
916 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
917 |
+
negative_prompt=negative_prompt,
|
918 |
+
control_image=control_image,
|
919 |
+
num_images=num_images,
|
920 |
+
num_steps=num_steps,
|
921 |
+
guidance_scale=guidance_scale,
|
922 |
+
seed=seed,
|
923 |
+
)
|
924 |
+
return [control_image, *results]
|
925 |
+
|
926 |
+
@torch.inference_mode()
|
927 |
+
def process_ip2p(
|
928 |
+
self,
|
929 |
+
image: np.ndarray,
|
930 |
+
prompt: str,
|
931 |
+
additional_prompt: str,
|
932 |
+
negative_prompt: str,
|
933 |
+
num_images: int,
|
934 |
+
image_resolution: int,
|
935 |
+
num_steps: int,
|
936 |
+
guidance_scale: float,
|
937 |
+
seed: int,
|
938 |
+
) -> list[PIL.Image.Image]:
|
939 |
+
if image is None:
|
940 |
+
raise ValueError
|
941 |
+
if image_resolution > MAX_IMAGE_RESOLUTION:
|
942 |
+
raise ValueError
|
943 |
+
if num_images > MAX_NUM_IMAGES:
|
944 |
+
raise ValueError
|
945 |
+
|
946 |
+
image = HWC3(image)
|
947 |
+
image = resize_image(image, resolution=image_resolution)
|
948 |
+
control_image = PIL.Image.fromarray(image)
|
949 |
+
self.load_controlnet_weight("ip2p")
|
950 |
+
results = self.run_pipe(
|
951 |
+
prompt=self.get_prompt(prompt, additional_prompt),
|
952 |
+
negative_prompt=negative_prompt,
|
953 |
+
control_image=control_image,
|
954 |
+
num_images=num_images,
|
955 |
+
num_steps=num_steps,
|
956 |
+
guidance_scale=guidance_scale,
|
957 |
+
seed=seed,
|
958 |
+
)
|
959 |
+
return [control_image, *results]
|
preprocessor.py
ADDED
@@ -0,0 +1,120 @@
1 |
+
import gc
import warnings  # needed by TexnetPreprocessor.__call__ for its deprecation warning
|
2 |
+
from typing import TYPE_CHECKING
|
3 |
+
|
4 |
+
if TYPE_CHECKING:
|
5 |
+
from collections.abc import Callable
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import PIL.Image
|
9 |
+
import torch
|
10 |
+
from controlnet_aux import (
|
11 |
+
CannyDetector,
|
12 |
+
ContentShuffleDetector,
|
13 |
+
HEDdetector,
|
14 |
+
LineartAnimeDetector,
|
15 |
+
LineartDetector,
|
16 |
+
MidasDetector,
|
17 |
+
MLSDdetector,
|
18 |
+
NormalBaeDetector,
|
19 |
+
OpenposeDetector,
|
20 |
+
PidiNetDetector,
|
21 |
+
)
|
22 |
+
from controlnet_aux.util import HWC3
|
23 |
+
|
24 |
+
from cv_utils import resize_image
|
25 |
+
from depth_estimator import DepthEstimator
|
26 |
+
from image_segmentor import ImageSegmentor
|
27 |
+
|
28 |
+
|
29 |
+
class Preprocessor:
|
30 |
+
MODEL_ID = "lllyasviel/Annotators"
|
31 |
+
|
32 |
+
def __init__(self) -> None:
|
33 |
+
self.model: Callable = None # type: ignore
|
34 |
+
self.name = ""
|
35 |
+
|
36 |
+
def load(self, name: str) -> None: # noqa: C901, PLR0912
|
37 |
+
if name == self.name:
|
38 |
+
return
|
39 |
+
if name == "HED":
|
40 |
+
self.model = HEDdetector.from_pretrained(self.MODEL_ID)
|
41 |
+
elif name == "Midas":
|
42 |
+
self.model = MidasDetector.from_pretrained(self.MODEL_ID)
|
43 |
+
elif name == "MLSD":
|
44 |
+
self.model = MLSDdetector.from_pretrained(self.MODEL_ID)
|
45 |
+
elif name == "Openpose":
|
46 |
+
self.model = OpenposeDetector.from_pretrained(self.MODEL_ID)
|
47 |
+
elif name == "PidiNet":
|
48 |
+
self.model = PidiNetDetector.from_pretrained(self.MODEL_ID)
|
49 |
+
elif name == "NormalBae":
|
50 |
+
self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID)
|
51 |
+
elif name == "Lineart":
|
52 |
+
self.model = LineartDetector.from_pretrained(self.MODEL_ID)
|
53 |
+
elif name == "LineartAnime":
|
54 |
+
self.model = LineartAnimeDetector.from_pretrained(self.MODEL_ID)
|
55 |
+
elif name == "Canny":
|
56 |
+
self.model = CannyDetector()
|
57 |
+
elif name == "ContentShuffle":
|
58 |
+
self.model = ContentShuffleDetector()
|
59 |
+
elif name == "DPT":
|
60 |
+
self.model = DepthEstimator()
|
61 |
+
elif name == "UPerNet":
|
62 |
+
self.model = ImageSegmentor()
|
63 |
+
elif name == 'texnet':
|
64 |
+
self.model = TexnetPreprocessor()
|
65 |
+
else:
|
66 |
+
raise ValueError
|
67 |
+
torch.cuda.empty_cache()
|
68 |
+
gc.collect()
|
69 |
+
self.name = name
|
70 |
+
|
71 |
+
def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image: # noqa: ANN003
|
72 |
+
if self.name == "Canny":
|
73 |
+
if "detect_resolution" in kwargs:
|
74 |
+
detect_resolution = kwargs.pop("detect_resolution")
|
75 |
+
image = np.array(image)
|
76 |
+
image = HWC3(image)
|
77 |
+
image = resize_image(image, resolution=detect_resolution)
|
78 |
+
image = self.model(image, **kwargs)
|
79 |
+
return PIL.Image.fromarray(image)
|
80 |
+
if self.name == "Midas":
|
81 |
+
detect_resolution = kwargs.pop("detect_resolution", 512)
|
82 |
+
image_resolution = kwargs.pop("image_resolution", 512)
|
83 |
+
image = np.array(image)
|
84 |
+
image = HWC3(image)
|
85 |
+
image = resize_image(image, resolution=detect_resolution)
|
86 |
+
image = self.model(image, **kwargs)
|
87 |
+
image = HWC3(image)
|
88 |
+
image = resize_image(image, resolution=image_resolution)
|
89 |
+
return PIL.Image.fromarray(image)
|
90 |
+
return self.model(image, **kwargs)
|
91 |
+
|
92 |
+
|
93 |
+
# https://github.com/huggingface/controlnet_aux/blob/master/src/controlnet_aux/canny/__init__.py
|
94 |
+
class TexnetPreprocessor:
|
95 |
+
def __call__(self, input_image=None, low_threshold=100, high_threshold=200, image_resolution=512, output_type=None, **kwargs):
|
96 |
+
if "img" in kwargs:
|
97 |
+
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
|
98 |
+
input_image = kwargs.pop("img")
|
99 |
+
|
100 |
+
if input_image is None:
|
101 |
+
raise ValueError("input_image must be defined.")
|
102 |
+
|
103 |
+
if not isinstance(input_image, np.ndarray):
|
104 |
+
input_image = np.array(input_image, dtype=np.uint8)
|
105 |
+
output_type = output_type or "pil"
|
106 |
+
else:
|
107 |
+
output_type = output_type or "np"
|
108 |
+
|
109 |
+
input_image = HWC3(input_image)
|
110 |
+
input_image = resize_image(input_image, image_resolution)
|
111 |
+
H, W, C = input_image.shape
|
112 |
+
|
113 |
+
# detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
114 |
+
output_image = input_image.copy()
|
115 |
+
|
116 |
+
if output_type == "pil":
|
117 |
+
# detected_map = Image.fromarray(detected_map)
|
118 |
+
output_image = PIL.Image.fromarray(output_image)
|
119 |
+
|
120 |
+
return output_image
|
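For reference, a minimal sketch of driving the Preprocessor wrapper above; the example image is one of the bundled example frames, and the output filename is a placeholder.

# Sketch: produce a Canny control image with the Preprocessor wrapper above.
import PIL.Image
from preprocessor import Preprocessor

preprocessor = Preprocessor()
preprocessor.load("Canny")                 # instantiates and caches CannyDetector
control = preprocessor(
    image=PIL.Image.open("examples/monkey/frame_0001.png").convert("RGB"),
    low_threshold=100,
    high_threshold=200,
    detect_resolution=512,
)
control.save("canny_control.png")          # PIL.Image with the detected edges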
push_dataset.py
ADDED
@@ -0,0 +1,9 @@
1 |
+
from huggingface_hub import HfApi
|
2 |
+
api = HfApi()
|
3 |
+
|
4 |
+
api.upload_folder(
|
5 |
+
folder_path="./examples",
|
6 |
+
repo_id="jingyangcarl/matgen",
|
7 |
+
repo_type="space",
|
8 |
+
path_in_repo="examples", # Upload to a specific folder
|
9 |
+
)
|
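push_dataset.py assumes the environment is already authenticated against the Hub; a minimal sketch of the login step is shown below (the token value is a placeholder).

# Authenticate once before running push_dataset.py; the token string is a placeholder.
from huggingface_hub import login

login(token="hf_xxxxxxxxxxxxxxxx")  # or run `huggingface-cli login` interactively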
rgb2x/generate_blend.py
ADDED
@@ -0,0 +1,142 @@
1 |
+
import bpy
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
|
5 |
+
def create_tex_node(nodes, img_path, label, color_space, location):
|
6 |
+
img = bpy.data.images.load(img_path)
|
7 |
+
tex = nodes.new(type='ShaderNodeTexImage')
|
8 |
+
tex.image = img
|
9 |
+
tex.label = label
|
10 |
+
tex.location = location
|
11 |
+
tex.image.colorspace_settings.name = color_space
|
12 |
+
return tex
|
13 |
+
|
14 |
+
def setup_environment_lighting(hdri_path):
|
15 |
+
if not bpy.data.worlds:
|
16 |
+
bpy.data.worlds.new(name="World")
|
17 |
+
if bpy.context.scene.world is None:
|
18 |
+
bpy.context.scene.world = bpy.data.worlds[0]
|
19 |
+
world = bpy.context.scene.world
|
20 |
+
|
21 |
+
world.use_nodes = True
|
22 |
+
nodes = world.node_tree.nodes
|
23 |
+
links = world.node_tree.links
|
24 |
+
nodes.clear()
|
25 |
+
|
26 |
+
env_tex = nodes.new(type="ShaderNodeTexEnvironment")
|
27 |
+
env_tex.image = bpy.data.images.load(hdri_path)
|
28 |
+
env_tex.location = (-300, 0)
|
29 |
+
|
30 |
+
bg = nodes.new(type="ShaderNodeBackground")
|
31 |
+
bg.location = (0, 0)
|
32 |
+
|
33 |
+
output = nodes.new(type="ShaderNodeOutputWorld")
|
34 |
+
output.location = (300, 0)
|
35 |
+
|
36 |
+
links.new(env_tex.outputs["Color"], bg.inputs["Color"])
|
37 |
+
links.new(bg.outputs["Background"], output.inputs["Surface"])
|
38 |
+
|
39 |
+
def setup_gpu_rendering():
|
40 |
+
bpy.context.scene.render.engine = 'CYCLES'
|
41 |
+
prefs = bpy.context.preferences
|
42 |
+
cprefs = prefs.addons['cycles'].preferences
|
43 |
+
|
44 |
+
# Choose backend depending on GPU type: 'CUDA', 'OPTIX', 'HIP', 'METAL'
|
45 |
+
cprefs.compute_device_type = 'CUDA'
|
46 |
+
bpy.context.scene.cycles.device = 'GPU'
|
47 |
+
|
48 |
+
def generate_blend(obj_path, base_color_path, normal_map_path, roughness_path, metallic_path, output_blend):
|
49 |
+
# Reset scene
|
50 |
+
bpy.ops.wm.read_factory_settings(use_empty=True)
|
51 |
+
|
52 |
+
# Import OBJ
|
53 |
+
bpy.ops.import_scene.obj(filepath=obj_path)
|
54 |
+
obj = bpy.context.selected_objects[0]
|
55 |
+
|
56 |
+
# Create material
|
57 |
+
mat = bpy.data.materials.new(name="BRDF_Material")
|
58 |
+
mat.use_nodes = True
|
59 |
+
nodes = mat.node_tree.nodes
|
60 |
+
links = mat.node_tree.links
|
61 |
+
nodes.clear()
|
62 |
+
|
63 |
+
output = nodes.new(type='ShaderNodeOutputMaterial')
|
64 |
+
output.location = (400, 0)
|
65 |
+
|
66 |
+
principled = nodes.new(type='ShaderNodeBsdfPrincipled')
|
67 |
+
principled.location = (100, 0)
|
68 |
+
links.new(principled.outputs['BSDF'], output.inputs['Surface'])
|
69 |
+
|
70 |
+
# Base Color
|
71 |
+
base_color = create_tex_node(nodes, base_color_path, "Base Color", 'sRGB', (-600, 200))
|
72 |
+
links.new(base_color.outputs['Color'], principled.inputs['Base Color'])
|
73 |
+
|
74 |
+
# Roughness
|
75 |
+
rough = create_tex_node(nodes, roughness_path, "Roughness", 'Non-Color', (-600, 0))
|
76 |
+
links.new(rough.outputs['Color'], principled.inputs['Roughness'])
|
77 |
+
|
78 |
+
# Metallic
|
79 |
+
metal = create_tex_node(nodes, metallic_path, "Metallic", 'Non-Color', (-600, -200))
|
80 |
+
links.new(metal.outputs['Color'], principled.inputs['Metallic'])
|
81 |
+
|
82 |
+
# Normal Map
|
83 |
+
normal_tex = create_tex_node(nodes, normal_map_path, "Normal Map", 'Non-Color', (-800, -400))
|
84 |
+
normal_map = nodes.new(type='ShaderNodeNormalMap')
|
85 |
+
normal_map.location = (-400, -400)
|
86 |
+
links.new(normal_tex.outputs['Color'], normal_map.inputs['Color'])
|
87 |
+
links.new(normal_map.outputs['Normal'], principled.inputs['Normal'])
|
88 |
+
|
89 |
+
# Assign material
|
90 |
+
if obj.data.materials:
|
91 |
+
obj.data.materials[0] = mat
|
92 |
+
else:
|
93 |
+
obj.data.materials.append(mat)
|
94 |
+
|
95 |
+
# Global Illumination using Blender's default forest HDRI
|
96 |
+
blender_data_path = bpy.utils.resource_path('LOCAL')
|
97 |
+
forest_hdri_path = os.path.join(blender_data_path, "datafiles", "studiolights", "world", "forest.exr")
|
98 |
+
print(f"Using HDRI: {forest_hdri_path}")
|
99 |
+
setup_environment_lighting(forest_hdri_path)
|
100 |
+
|
101 |
+
# GPU rendering setup
|
102 |
+
setup_gpu_rendering()
|
103 |
+
|
104 |
+
# Pack textures into .blend
|
105 |
+
bpy.ops.file.pack_all()
|
106 |
+
|
107 |
+
# Set the 3D View to Rendered mode and focus on object
|
108 |
+
for area in bpy.context.screen.areas:
|
109 |
+
if area.type == 'VIEW_3D':
|
110 |
+
for space in area.spaces:
|
111 |
+
if space.type == 'VIEW_3D':
|
112 |
+
space.shading.type = 'RENDERED' # Set viewport shading to Rendered
|
113 |
+
for region in area.regions:
|
114 |
+
if region.type == 'WINDOW':
|
115 |
+
override = {'area': area, 'region': region, 'scene': bpy.context.scene}
|
116 |
+
bpy.ops.view3d.view_all(override, center=True)
|
117 |
+
|
118 |
+
elif area.type == 'NODE_EDITOR':
|
119 |
+
for space in area.spaces:
|
120 |
+
if space.type == 'NODE_EDITOR':
|
121 |
+
space.tree_type = 'ShaderNodeTree' # Switch to Shader Editor
|
122 |
+
space.shader_type = 'OBJECT'
|
123 |
+
|
124 |
+
# Optional: Switch active workspace to Shading (if it exists)
|
125 |
+
for workspace in bpy.data.workspaces:
|
126 |
+
if workspace.name == 'Shading':
|
127 |
+
bpy.context.window.workspace = workspace
|
128 |
+
break
|
129 |
+
|
130 |
+
# Save the .blend file
|
131 |
+
bpy.ops.wm.save_as_mainfile(filepath=output_blend)
|
132 |
+
print(f"✅ Saved .blend file with BRDF, HDRI, GPU: {output_blend}")
|
133 |
+
|
134 |
+
if __name__ == "__main__":
|
135 |
+
argv = sys.argv
|
136 |
+
argv = argv[argv.index("--") + 1:] # Only use args after "--"
|
137 |
+
|
138 |
+
if len(argv) != 6:
|
139 |
+
print("Usage:\n blender --background --python generate_blend.py -- obj base_color normal roughness metallic output.blend")
|
140 |
+
sys.exit(1)
|
141 |
+
|
142 |
+
generate_blend(*argv)
|
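Not part of the repo, but as a sanity check the packed .blend written by generate_blend() can be rendered headlessly with the same Blender Python API. A minimal sketch with placeholder paths; note that generate_blend() does not add a camera, so one is created here.

# Sketch: render a preview still from the packed .blend (run inside Blender's Python).
import bpy

bpy.ops.wm.open_mainfile(filepath="output/monkey_textured.blend")  # placeholder path
scene = bpy.context.scene

# The .blend written by generate_blend() contains no camera, so add one for the preview.
cam_data = bpy.data.cameras.new("PreviewCam")
cam_obj = bpy.data.objects.new("PreviewCam", cam_data)
scene.collection.objects.link(cam_obj)
cam_obj.location = (0.0, -3.0, 1.5)
cam_obj.rotation_euler = (1.2, 0.0, 0.0)
scene.camera = cam_obj

scene.render.resolution_x = 512
scene.render.resolution_y = 512
scene.render.filepath = "/tmp/preview.png"                          # placeholder path
bpy.ops.render.render(write_still=True)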
rgb2x/gradio_demo_rgb2x.py
ADDED
@@ -0,0 +1,157 @@
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
from diffusers import DDIMScheduler
|
9 |
+
from load_image import load_exr_image, load_ldr_image
|
10 |
+
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
|
11 |
+
|
12 |
+
current_directory = os.path.dirname(os.path.abspath(__file__))
|
13 |
+
|
14 |
+
|
15 |
+
def get_rgb2x_demo():
|
16 |
+
# Load pipeline
|
17 |
+
pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
|
18 |
+
"zheng95z/rgb-to-x",
|
19 |
+
torch_dtype=torch.float16,
|
20 |
+
cache_dir=os.path.join(current_directory, "model_cache"),
|
21 |
+
).to("cuda")
|
22 |
+
pipe.scheduler = DDIMScheduler.from_config(
|
23 |
+
pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
24 |
+
)
|
25 |
+
pipe.set_progress_bar_config(disable=True)
|
26 |
+
pipe.to("cuda")
|
27 |
+
|
28 |
+
# Augmentation
|
29 |
+
def callback(
|
30 |
+
photo,
|
31 |
+
seed,
|
32 |
+
inference_step,
|
33 |
+
num_samples,
|
34 |
+
):
|
35 |
+
generator = torch.Generator(device="cuda").manual_seed(seed)
|
36 |
+
|
37 |
+
if photo.name.endswith(".exr"):
|
38 |
+
photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
|
39 |
+
elif (
|
40 |
+
photo.name.endswith(".png")
|
41 |
+
or photo.name.endswith(".jpg")
|
42 |
+
or photo.name.endswith(".jpeg")
|
43 |
+
):
|
44 |
+
photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
|
45 |
+
|
46 |
+
# Resize so that the longer side is at most 1000 px and both dimensions are multiples of 8 (required by the diffusion VAE)
|
47 |
+
old_height = photo.shape[1]
|
48 |
+
old_width = photo.shape[2]
|
49 |
+
new_height = old_height
|
50 |
+
new_width = old_width
|
51 |
+
radio = old_height / old_width
|
52 |
+
max_side = 1000
|
53 |
+
if old_height > old_width:
|
54 |
+
new_height = max_side
|
55 |
+
new_width = int(new_height / radio)
|
56 |
+
else:
|
57 |
+
new_width = max_side
|
58 |
+
new_height = int(new_width * radio)
|
59 |
+
|
60 |
+
if new_width % 8 != 0 or new_height % 8 != 0:
|
61 |
+
new_width = new_width // 8 * 8
|
62 |
+
new_height = new_height // 8 * 8
|
63 |
+
|
64 |
+
photo = torchvision.transforms.Resize((new_height, new_width))(photo)
|
65 |
+
|
66 |
+
required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
|
67 |
+
prompts = {
|
68 |
+
"albedo": "Albedo (diffuse basecolor)",
|
69 |
+
"normal": "Camera-space Normal",
|
70 |
+
"roughness": "Roughness",
|
71 |
+
"metallic": "Metallicness",
|
72 |
+
"irradiance": "Irradiance (diffuse lighting)",
|
73 |
+
}
|
74 |
+
|
75 |
+
return_list = []
|
76 |
+
for i in range(num_samples):
|
77 |
+
for aov_name in required_aovs:
|
78 |
+
prompt = prompts[aov_name]
|
79 |
+
generated_image = pipe(
|
80 |
+
prompt=prompt,
|
81 |
+
photo=photo,
|
82 |
+
num_inference_steps=inference_step,
|
83 |
+
height=new_height,
|
84 |
+
width=new_width,
|
85 |
+
generator=generator,
|
86 |
+
required_aovs=[aov_name],
|
87 |
+
).images[0][0]
|
88 |
+
|
89 |
+
generated_image = torchvision.transforms.Resize(
|
90 |
+
(old_height, old_width)
|
91 |
+
)(generated_image)
|
92 |
+
|
93 |
+
generated_image = (generated_image, f"Generated {aov_name} {i}")
|
94 |
+
return_list.append(generated_image)
|
95 |
+
|
96 |
+
return return_list
|
97 |
+
|
98 |
+
block = gr.Blocks()
|
99 |
+
with block:
|
100 |
+
with gr.Row():
|
101 |
+
gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
|
102 |
+
with gr.Row():
|
103 |
+
# Input side
|
104 |
+
with gr.Column():
|
105 |
+
gr.Markdown("### Given Image")
|
106 |
+
photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])
|
107 |
+
|
108 |
+
gr.Markdown("### Parameters")
|
109 |
+
run_button = gr.Button(value="Run")
|
110 |
+
with gr.Accordion("Advanced options", open=False):
|
111 |
+
seed = gr.Slider(
|
112 |
+
label="Seed",
|
113 |
+
minimum=-1,
|
114 |
+
maximum=2147483647,
|
115 |
+
step=1,
|
116 |
+
randomize=True,
|
117 |
+
)
|
118 |
+
inference_step = gr.Slider(
|
119 |
+
label="Inference Step",
|
120 |
+
minimum=1,
|
121 |
+
maximum=100,
|
122 |
+
step=1,
|
123 |
+
value=50,
|
124 |
+
)
|
125 |
+
num_samples = gr.Slider(
|
126 |
+
label="Samples",
|
127 |
+
minimum=1,
|
128 |
+
maximum=100,
|
129 |
+
step=1,
|
130 |
+
value=1,
|
131 |
+
)
|
132 |
+
|
133 |
+
# Output side
|
134 |
+
with gr.Column():
|
135 |
+
gr.Markdown("### Output Gallery")
|
136 |
+
result_gallery = gr.Gallery(
|
137 |
+
label="Output",
|
138 |
+
show_label=False,
|
139 |
+
elem_id="gallery",
|
140 |
+
columns=2,
|
141 |
+
)
|
142 |
+
|
143 |
+
inputs = [
|
144 |
+
photo,
|
145 |
+
seed,
|
146 |
+
inference_step,
|
147 |
+
num_samples,
|
148 |
+
]
|
149 |
+
run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True)
|
150 |
+
|
151 |
+
return block
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
|
155 |
+
demo = get_rgb2x_demo()
|
156 |
+
demo.queue(max_size=1)
|
157 |
+
demo.launch()
|
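The Blocks returned by get_rgb2x_demo() can also be mounted inside a larger Gradio app instead of being launched standalone; a minimal sketch follows (the tab title and import path are illustrative).

# Sketch: mount the rgb2x demo as one tab of a combined Gradio app.
import gradio as gr
from gradio_demo_rgb2x import get_rgb2x_demo  # run from inside the rgb2x/ directory

app = gr.TabbedInterface([get_rgb2x_demo()], ["RGB -> X"])
app.queue(max_size=1)
app.launch()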
rgb2x/load_image.py
ADDED
@@ -0,0 +1,119 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
|
6 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
def convert_rgb_2_XYZ(rgb):
|
11 |
+
# Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
|
12 |
+
# rgb: (h, w, 3)
|
13 |
+
# XYZ: (h, w, 3)
|
14 |
+
XYZ = torch.ones_like(rgb)
|
15 |
+
XYZ[:, :, 0] = (
|
16 |
+
0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2]
|
17 |
+
)
|
18 |
+
XYZ[:, :, 1] = (
|
19 |
+
0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2]
|
20 |
+
)
|
21 |
+
XYZ[:, :, 2] = (
|
22 |
+
0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2]
|
23 |
+
)
|
24 |
+
return XYZ
|
25 |
+
|
26 |
+
|
27 |
+
def convert_XYZ_2_Yxy(XYZ):
|
28 |
+
# XYZ: (h, w, 3)
|
29 |
+
# Yxy: (h, w, 3)
|
30 |
+
Yxy = torch.ones_like(XYZ)
|
31 |
+
Yxy[:, :, 0] = XYZ[:, :, 1]
|
32 |
+
sum = torch.sum(XYZ, dim=2)
|
33 |
+
inv_sum = 1.0 / torch.clamp(sum, min=1e-4)
|
34 |
+
Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum
|
35 |
+
Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum
|
36 |
+
return Yxy
|
37 |
+
|
38 |
+
|
39 |
+
def convert_rgb_2_Yxy(rgb):
|
40 |
+
# rgb: (h, w, 3)
|
41 |
+
# Yxy: (h, w, 3)
|
42 |
+
return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb))
|
43 |
+
|
44 |
+
|
45 |
+
def convert_XYZ_2_rgb(XYZ):
|
46 |
+
# XYZ: (h, w, 3)
|
47 |
+
# rgb: (h, w, 3)
|
48 |
+
rgb = torch.ones_like(XYZ)
|
49 |
+
rgb[:, :, 0] = (
|
50 |
+
3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2]
|
51 |
+
)
|
52 |
+
rgb[:, :, 1] = (
|
53 |
+
-0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2]
|
54 |
+
)
|
55 |
+
rgb[:, :, 2] = (
|
56 |
+
0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2]
|
57 |
+
)
|
58 |
+
return rgb
|
59 |
+
|
60 |
+
|
61 |
+
def convert_Yxy_2_XYZ(Yxy):
|
62 |
+
# Yxy: (h, w, 3)
|
63 |
+
# XYZ: (h, w, 3)
|
64 |
+
XYZ = torch.ones_like(Yxy)
|
65 |
+
XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0]
|
66 |
+
XYZ[:, :, 1] = Yxy[:, :, 0]
|
67 |
+
XYZ[:, :, 2] = (
|
68 |
+
(1.0 - Yxy[:, :, 1] - Yxy[:, :, 2])
|
69 |
+
/ torch.clamp(Yxy[:, :, 2], min=1e-4)
|
70 |
+
* Yxy[:, :, 0]
|
71 |
+
)
|
72 |
+
return XYZ
|
73 |
+
|
74 |
+
|
75 |
+
def convert_Yxy_2_rgb(Yxy):
|
76 |
+
# Yxy: (h, w, 3)
|
77 |
+
# rgb: (h, w, 3)
|
78 |
+
return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy))
|
79 |
+
|
80 |
+
|
81 |
+
def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
|
82 |
+
# Load png or jpg image
|
83 |
+
image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
84 |
+
image = torch.from_numpy(image.astype(np.float32) / 255.0) # (h, w, c)
|
85 |
+
image[~torch.isfinite(image)] = 0
|
86 |
+
if from_srgb:
|
87 |
+
# Convert from sRGB to linear RGB
|
88 |
+
image = image**2.2
|
89 |
+
if clamp:
|
90 |
+
image = torch.clamp(image, min=0.0, max=1.0)
|
91 |
+
if normalize:
|
92 |
+
# Normalize to [-1, 1]
|
93 |
+
image = image * 2.0 - 1.0
|
94 |
+
image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
|
95 |
+
return image.permute(2, 0, 1) # returns (c, h, w)
|
96 |
+
|
97 |
+
|
98 |
+
def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
|
99 |
+
image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB)
|
100 |
+
image = torch.from_numpy(image.astype("float32")) # (h, w, c)
|
101 |
+
image[~torch.isfinite(image)] = 0
|
102 |
+
if tonemaping:
|
103 |
+
# Exposure adjustment
|
104 |
+
image_Yxy = convert_rgb_2_Yxy(image)
|
105 |
+
lum = (
|
106 |
+
image[:, :, 0:1] * 0.2125
|
107 |
+
+ image[:, :, 1:2] * 0.7154
|
108 |
+
+ image[:, :, 2:3] * 0.0721
|
109 |
+
)
|
110 |
+
lum = torch.log(torch.clamp(lum, min=1e-6))
|
111 |
+
lum_mean = torch.exp(torch.mean(lum))
|
112 |
+
lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
|
113 |
+
image_Yxy[:, :, 0:1] = lp
|
114 |
+
image = convert_Yxy_2_rgb(image_Yxy)
|
115 |
+
if clamp:
|
116 |
+
image = torch.clamp(image, min=0.0, max=1.0)
|
117 |
+
if normalize:
|
118 |
+
image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
|
119 |
+
return image.permute(2, 0, 1) # returns (c, h, w)
|
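For reference, a minimal sketch of using the loaders above to prepare the (c, h, w) tensor that the rgb2x pipeline expects; the image path points at one of the bundled examples.

# Sketch: load an sRGB texture as a linear-RGB tensor in (c, h, w) layout.
from load_image import load_ldr_image

photo = load_ldr_image("examples/monkey/frame_0001.png", from_srgb=True)
print(photo.shape)        # torch.Size([3, H, W]), values approximately in [0, 1]
photo = photo.to("cuda")  # the rgb2x pipeline runs on the GPU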
rgb2x/pipeline_rgb2x.py
ADDED
@@ -0,0 +1,821 @@
1 |
+
import inspect
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Callable, List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import PIL
|
7 |
+
import torch
|
8 |
+
from diffusers.configuration_utils import register_to_config
|
9 |
+
from diffusers.image_processor import VaeImageProcessor
|
10 |
+
from diffusers.loaders import (
|
11 |
+
LoraLoaderMixin,
|
12 |
+
TextualInversionLoaderMixin,
|
13 |
+
)
|
14 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
15 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
16 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
17 |
+
rescale_noise_cfg,
|
18 |
+
)
|
19 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
20 |
+
from diffusers.utils import (
|
21 |
+
CONFIG_NAME,
|
22 |
+
BaseOutput,
|
23 |
+
deprecate,
|
24 |
+
logging,
|
25 |
+
)
|
26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
27 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
class VaeImageProcrssorAOV(VaeImageProcessor):
|
33 |
+
"""
|
34 |
+
Image processor for VAE AOV.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
38 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
|
39 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
40 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
41 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
42 |
+
Resampling filter to use when resizing the image.
|
43 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
44 |
+
Whether to normalize the image to [-1,1].
|
45 |
+
"""
|
46 |
+
|
47 |
+
config_name = CONFIG_NAME
|
48 |
+
|
49 |
+
@register_to_config
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
do_resize: bool = True,
|
53 |
+
vae_scale_factor: int = 8,
|
54 |
+
resample: str = "lanczos",
|
55 |
+
do_normalize: bool = True,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
|
59 |
+
def postprocess(
|
60 |
+
self,
|
61 |
+
image: torch.FloatTensor,
|
62 |
+
output_type: str = "pil",
|
63 |
+
do_denormalize: Optional[List[bool]] = None,
|
64 |
+
do_gamma_correction: bool = True,
|
65 |
+
):
|
66 |
+
if not isinstance(image, torch.Tensor):
|
67 |
+
raise ValueError(
|
68 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
69 |
+
)
|
70 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
71 |
+
deprecation_message = (
|
72 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
73 |
+
"`pil`, `np`, `pt`, `latent`"
|
74 |
+
)
|
75 |
+
deprecate(
|
76 |
+
"Unsupported output_type",
|
77 |
+
"1.0.0",
|
78 |
+
deprecation_message,
|
79 |
+
standard_warn=False,
|
80 |
+
)
|
81 |
+
output_type = "np"
|
82 |
+
|
83 |
+
if output_type == "latent":
|
84 |
+
return image
|
85 |
+
|
86 |
+
if do_denormalize is None:
|
87 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
88 |
+
|
89 |
+
image = torch.stack(
|
90 |
+
[
|
91 |
+
self.denormalize(image[i]) if do_denormalize[i] else image[i]
|
92 |
+
for i in range(image.shape[0])
|
93 |
+
]
|
94 |
+
)
|
95 |
+
|
96 |
+
# Gamma correction
|
97 |
+
if do_gamma_correction:
|
98 |
+
image = torch.pow(image, 1.0 / 2.2)
|
99 |
+
|
100 |
+
if output_type == "pt":
|
101 |
+
return image
|
102 |
+
|
103 |
+
image = self.pt_to_numpy(image)
|
104 |
+
|
105 |
+
if output_type == "np":
|
106 |
+
return image
|
107 |
+
|
108 |
+
if output_type == "pil":
|
109 |
+
return self.numpy_to_pil(image)
|
110 |
+
|
111 |
+
def preprocess_normal(
|
112 |
+
self,
|
113 |
+
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
114 |
+
height: Optional[int] = None,
|
115 |
+
width: Optional[int] = None,
|
116 |
+
) -> torch.Tensor:
|
117 |
+
image = torch.stack([image], axis=0)
|
118 |
+
return image
|
119 |
+
|
120 |
+
|
121 |
+
@dataclass
|
122 |
+
class StableDiffusionAOVPipelineOutput(BaseOutput):
|
123 |
+
"""
|
124 |
+
Output class for Stable Diffusion AOV pipelines.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
128 |
+
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
|
129 |
+
num_channels)`.
|
130 |
+
nsfw_content_detected (`List[bool]`)
|
131 |
+
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
|
132 |
+
`None` if safety checking could not be performed.
|
133 |
+
"""
|
134 |
+
|
135 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
136 |
+
|
137 |
+
|
138 |
+
class StableDiffusionAOVMatEstPipeline(
|
139 |
+
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
|
140 |
+
):
|
141 |
+
r"""
|
142 |
+
Pipeline for AOVs.
|
143 |
+
|
144 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
145 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
146 |
+
|
147 |
+
The pipeline also inherits the following loading methods:
|
148 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
149 |
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
150 |
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
151 |
+
|
152 |
+
Args:
|
153 |
+
vae ([`AutoencoderKL`]):
|
154 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
155 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
156 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
157 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
158 |
+
A `CLIPTokenizer` to tokenize text.
|
159 |
+
unet ([`UNet2DConditionModel`]):
|
160 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
161 |
+
scheduler ([`SchedulerMixin`]):
|
162 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
163 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
164 |
+
"""
|
165 |
+
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
vae: AutoencoderKL,
|
169 |
+
text_encoder: CLIPTextModel,
|
170 |
+
tokenizer: CLIPTokenizer,
|
171 |
+
unet: UNet2DConditionModel,
|
172 |
+
scheduler: KarrasDiffusionSchedulers,
|
173 |
+
):
|
174 |
+
super().__init__()
|
175 |
+
|
176 |
+
self.register_modules(
|
177 |
+
vae=vae,
|
178 |
+
text_encoder=text_encoder,
|
179 |
+
tokenizer=tokenizer,
|
180 |
+
unet=unet,
|
181 |
+
scheduler=scheduler,
|
182 |
+
)
|
183 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
184 |
+
self.image_processor = VaeImageProcrssorAOV(
|
185 |
+
vae_scale_factor=self.vae_scale_factor
|
186 |
+
)
|
187 |
+
self.register_to_config()
|
188 |
+
|
189 |
+
def _encode_prompt(
|
190 |
+
self,
|
191 |
+
prompt,
|
192 |
+
device,
|
193 |
+
num_images_per_prompt,
|
194 |
+
do_classifier_free_guidance,
|
195 |
+
negative_prompt=None,
|
196 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
197 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
198 |
+
):
|
199 |
+
r"""
|
200 |
+
Encodes the prompt into text encoder hidden states.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
prompt (`str` or `List[str]`, *optional*):
|
204 |
+
prompt to be encoded
|
205 |
+
device (`torch.device`):
|
206 |
+
torch device
|
207 |
+
num_images_per_prompt (`int`):
|
208 |
+
number of images that should be generated per prompt
|
209 |
+
do_classifier_free_guidance (`bool`):
|
210 |
+
whether to use classifier free guidance or not
|
211 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
212 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
213 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
214 |
+
less than `1`).
|
215 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
216 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
217 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
218 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
219 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
220 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
221 |
+
argument.
|
222 |
+
"""
|
223 |
+
if prompt is not None and isinstance(prompt, str):
|
224 |
+
batch_size = 1
|
225 |
+
elif prompt is not None and isinstance(prompt, list):
|
226 |
+
batch_size = len(prompt)
|
227 |
+
else:
|
228 |
+
batch_size = prompt_embeds.shape[0]
|
229 |
+
|
230 |
+
if prompt_embeds is None:
|
231 |
+
# textual inversion: process multi-vector tokens if necessary
|
232 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
233 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
234 |
+
|
235 |
+
text_inputs = self.tokenizer(
|
236 |
+
prompt,
|
237 |
+
padding="max_length",
|
238 |
+
max_length=self.tokenizer.model_max_length,
|
239 |
+
truncation=True,
|
240 |
+
return_tensors="pt",
|
241 |
+
)
|
242 |
+
text_input_ids = text_inputs.input_ids
|
243 |
+
untruncated_ids = self.tokenizer(
|
244 |
+
prompt, padding="longest", return_tensors="pt"
|
245 |
+
).input_ids
|
246 |
+
|
247 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
248 |
+
-1
|
249 |
+
] and not torch.equal(text_input_ids, untruncated_ids):
|
250 |
+
removed_text = self.tokenizer.batch_decode(
|
251 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
252 |
+
)
|
253 |
+
logger.warning(
|
254 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
255 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
256 |
+
)
|
257 |
+
|
258 |
+
if (
|
259 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
260 |
+
and self.text_encoder.config.use_attention_mask
|
261 |
+
):
|
262 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
263 |
+
else:
|
264 |
+
attention_mask = None
|
265 |
+
|
266 |
+
prompt_embeds = self.text_encoder(
|
267 |
+
text_input_ids.to(device),
|
268 |
+
attention_mask=attention_mask,
|
269 |
+
)
|
270 |
+
prompt_embeds = prompt_embeds[0]
|
271 |
+
|
272 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
273 |
+
|
274 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
275 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
276 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
277 |
+
prompt_embeds = prompt_embeds.view(
|
278 |
+
bs_embed * num_images_per_prompt, seq_len, -1
|
279 |
+
)
|
280 |
+
|
281 |
+
# get unconditional embeddings for classifier free guidance
|
282 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
283 |
+
uncond_tokens: List[str]
|
284 |
+
if negative_prompt is None:
|
285 |
+
uncond_tokens = [""] * batch_size
|
286 |
+
elif type(prompt) is not type(negative_prompt):
|
287 |
+
raise TypeError(
|
288 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
289 |
+
f" {type(prompt)}."
|
290 |
+
)
|
291 |
+
elif isinstance(negative_prompt, str):
|
292 |
+
uncond_tokens = [negative_prompt]
|
293 |
+
elif batch_size != len(negative_prompt):
|
294 |
+
raise ValueError(
|
295 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
296 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
297 |
+
" the batch size of `prompt`."
|
298 |
+
)
|
299 |
+
else:
|
300 |
+
uncond_tokens = negative_prompt
|
301 |
+
|
302 |
+
# textual inversion: process multi-vector tokens if necessary
|
303 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
304 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
305 |
+
|
306 |
+
max_length = prompt_embeds.shape[1]
|
307 |
+
uncond_input = self.tokenizer(
|
308 |
+
uncond_tokens,
|
309 |
+
padding="max_length",
|
310 |
+
max_length=max_length,
|
311 |
+
truncation=True,
|
312 |
+
return_tensors="pt",
|
313 |
+
)
|
314 |
+
|
315 |
+
if (
|
316 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
317 |
+
and self.text_encoder.config.use_attention_mask
|
318 |
+
):
|
319 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
320 |
+
else:
|
321 |
+
attention_mask = None
|
322 |
+
|
323 |
+
negative_prompt_embeds = self.text_encoder(
|
324 |
+
uncond_input.input_ids.to(device),
|
325 |
+
attention_mask=attention_mask,
|
326 |
+
)
|
327 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
328 |
+
|
329 |
+
if do_classifier_free_guidance:
|
330 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
331 |
+
seq_len = negative_prompt_embeds.shape[1]
|
332 |
+
|
333 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
334 |
+
dtype=self.text_encoder.dtype, device=device
|
335 |
+
)
|
336 |
+
|
337 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
338 |
+
1, num_images_per_prompt, 1
|
339 |
+
)
|
340 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
341 |
+
batch_size * num_images_per_prompt, seq_len, -1
|
342 |
+
)
|
343 |
+
|
344 |
+
# For classifier free guidance, we need to do two forward passes.
|
345 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
346 |
+
# to avoid doing two forward passes
|
347 |
+
# pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
|
348 |
+
prompt_embeds = torch.cat(
|
349 |
+
[prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
|
350 |
+
)
|
351 |
+
|
352 |
+
return prompt_embeds
|
353 |
+
|
354 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
355 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
356 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
357 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
358 |
+
# and should be between [0, 1]
|
359 |
+
|
360 |
+
accepts_eta = "eta" in set(
|
361 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
362 |
+
)
|
363 |
+
extra_step_kwargs = {}
|
364 |
+
if accepts_eta:
|
365 |
+
extra_step_kwargs["eta"] = eta
|
366 |
+
|
367 |
+
# check if the scheduler accepts generator
|
368 |
+
accepts_generator = "generator" in set(
|
369 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
370 |
+
)
|
371 |
+
if accepts_generator:
|
372 |
+
extra_step_kwargs["generator"] = generator
|
373 |
+
return extra_step_kwargs
|
374 |
+
|
375 |
+
def check_inputs(
|
376 |
+
self,
|
377 |
+
prompt,
|
378 |
+
callback_steps,
|
379 |
+
negative_prompt=None,
|
380 |
+
prompt_embeds=None,
|
381 |
+
negative_prompt_embeds=None,
|
382 |
+
):
|
383 |
+
if (callback_steps is None) or (
|
384 |
+
callback_steps is not None
|
385 |
+
and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
386 |
+
):
|
387 |
+
raise ValueError(
|
388 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
389 |
+
f" {type(callback_steps)}."
|
390 |
+
)
|
391 |
+
|
392 |
+
if prompt is not None and prompt_embeds is not None:
|
393 |
+
raise ValueError(
|
394 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
395 |
+
" only forward one of the two."
|
396 |
+
)
|
397 |
+
elif prompt is None and prompt_embeds is None:
|
398 |
+
raise ValueError(
|
399 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
400 |
+
)
|
401 |
+
elif prompt is not None and (
|
402 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
403 |
+
):
|
404 |
+
raise ValueError(
|
405 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
406 |
+
)
|
407 |
+
|
408 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
409 |
+
raise ValueError(
|
410 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
411 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
412 |
+
)
|
413 |
+
|
414 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
415 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
416 |
+
raise ValueError(
|
417 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
418 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
419 |
+
f" {negative_prompt_embeds.shape}."
|
420 |
+
)
|
421 |
+
|
422 |
+
def prepare_latents(
|
423 |
+
self,
|
424 |
+
batch_size,
|
425 |
+
num_channels_latents,
|
426 |
+
height,
|
427 |
+
width,
|
428 |
+
dtype,
|
429 |
+
device,
|
430 |
+
generator,
|
431 |
+
latents=None,
|
432 |
+
):
|
433 |
+
shape = (
|
434 |
+
batch_size,
|
435 |
+
num_channels_latents,
|
436 |
+
height // self.vae_scale_factor,
|
437 |
+
width // self.vae_scale_factor,
|
438 |
+
)
|
439 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
440 |
+
raise ValueError(
|
441 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
442 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
443 |
+
)
|
444 |
+
|
445 |
+
if latents is None:
|
446 |
+
latents = randn_tensor(
|
447 |
+
shape, generator=generator, device=device, dtype=dtype
|
448 |
+
)
|
449 |
+
else:
|
450 |
+
latents = latents.to(device)
|
451 |
+
|
452 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
453 |
+
latents = latents * self.scheduler.init_noise_sigma
|
454 |
+
return latents
|
455 |
+
|
456 |
+
def prepare_image_latents(
|
457 |
+
self,
|
458 |
+
image,
|
459 |
+
batch_size,
|
460 |
+
num_images_per_prompt,
|
461 |
+
dtype,
|
462 |
+
device,
|
463 |
+
do_classifier_free_guidance,
|
464 |
+
generator=None,
|
465 |
+
):
|
466 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
467 |
+
raise ValueError(
|
468 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
469 |
+
)
|
470 |
+
|
471 |
+
image = image.to(device=device, dtype=dtype)
|
472 |
+
|
473 |
+
batch_size = batch_size * num_images_per_prompt
|
474 |
+
|
475 |
+
if image.shape[1] == 4:
|
476 |
+
image_latents = image
|
477 |
+
else:
|
478 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
479 |
+
raise ValueError(
|
480 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
481 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
482 |
+
)
|
483 |
+
|
484 |
+
if isinstance(generator, list):
|
485 |
+
image_latents = [
|
486 |
+
self.vae.encode(image[i : i + 1]).latent_dist.mode()
|
487 |
+
for i in range(batch_size)
|
488 |
+
]
|
489 |
+
image_latents = torch.cat(image_latents, dim=0)
|
490 |
+
else:
|
491 |
+
image_latents = self.vae.encode(image).latent_dist.mode()
|
492 |
+
|
493 |
+
if (
|
494 |
+
batch_size > image_latents.shape[0]
|
495 |
+
and batch_size % image_latents.shape[0] == 0
|
496 |
+
):
|
497 |
+
# expand image_latents for batch_size
|
498 |
+
deprecation_message = (
|
499 |
+
f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
|
500 |
+
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
|
501 |
+
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
502 |
+
" your script to pass as many initial images as text prompts to suppress this warning."
|
503 |
+
)
|
504 |
+
deprecate(
|
505 |
+
"len(prompt) != len(image)",
|
506 |
+
"1.0.0",
|
507 |
+
deprecation_message,
|
508 |
+
standard_warn=False,
|
509 |
+
)
|
510 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
511 |
+
image_latents = torch.cat(
|
512 |
+
[image_latents] * additional_image_per_prompt, dim=0
|
513 |
+
)
|
514 |
+
elif (
|
515 |
+
batch_size > image_latents.shape[0]
|
516 |
+
and batch_size % image_latents.shape[0] != 0
|
517 |
+
):
|
518 |
+
raise ValueError(
|
519 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
520 |
+
)
|
521 |
+
else:
|
522 |
+
image_latents = torch.cat([image_latents], dim=0)
|
523 |
+
|
524 |
+
if do_classifier_free_guidance:
|
525 |
+
uncond_image_latents = torch.zeros_like(image_latents)
|
526 |
+
image_latents = torch.cat(
|
527 |
+
[image_latents, image_latents, uncond_image_latents], dim=0
|
528 |
+
)
|
529 |
+
|
530 |
+
return image_latents
|
531 |
+
|
532 |
+
@torch.no_grad()
|
533 |
+
def __call__(
|
534 |
+
self,
|
535 |
+
prompt: Union[str, List[str]] = None,
|
536 |
+
photo: Union[
|
537 |
+
torch.FloatTensor,
|
538 |
+
PIL.Image.Image,
|
539 |
+
np.ndarray,
|
540 |
+
List[torch.FloatTensor],
|
541 |
+
List[PIL.Image.Image],
|
542 |
+
List[np.ndarray],
|
543 |
+
] = None,
|
544 |
+
height: Optional[int] = None,
|
545 |
+
width: Optional[int] = None,
|
546 |
+
num_inference_steps: int = 100,
|
547 |
+
required_aovs: List[str] = ["albedo"],
|
548 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
549 |
+
num_images_per_prompt: Optional[int] = 1,
|
550 |
+
use_default_scaling_factor: Optional[bool] = False,
|
551 |
+
guidance_scale: float = 0.0,
|
552 |
+
image_guidance_scale: float = 0.0,
|
553 |
+
guidance_rescale: float = 0.0,
|
554 |
+
eta: float = 0.0,
|
555 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
556 |
+
latents: Optional[torch.FloatTensor] = None,
|
557 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
558 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
559 |
+
output_type: Optional[str] = "pil",
|
560 |
+
return_dict: bool = True,
|
561 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
562 |
+
callback_steps: int = 1,
|
563 |
+
):
|
564 |
+
r"""
|
565 |
+
The call function to the pipeline for generation.
|
566 |
+
|
567 |
+
Args:
|
568 |
+
prompt (`str` or `List[str]`, *optional*):
|
569 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
570 |
+
photo (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
571 |
+
`Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
|
572 |
+
image latents as `image`, but if passing latents directly it is not encoded again.
|
573 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
574 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
575 |
+
expense of slower inference.
|
576 |
+
guidance_scale (`float`, *optional*, defaults to 0.0):
|
577 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
578 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
579 |
+
image_guidance_scale (`float`, *optional*, defaults to 0.0):
|
580 |
+
Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
|
581 |
+
`image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
|
582 |
+
linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
|
583 |
+
value of at least `1`.
|
584 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
585 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
586 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
587 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
588 |
+
The number of images to generate per prompt.
|
589 |
+
eta (`float`, *optional*, defaults to 0.0):
|
590 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
591 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
592 |
+
generator (`torch.Generator`, *optional*):
|
593 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
594 |
+
generation deterministic.
|
595 |
+
latents (`torch.FloatTensor`, *optional*):
|
596 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
597 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
598 |
+
tensor is generated by sampling using the supplied random `generator`.
|
599 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
600 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
601 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
602 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
603 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
604 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
605 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
606 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
607 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
608 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
609 |
+
plain tuple.
|
610 |
+
callback (`Callable`, *optional*):
|
611 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
612 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
613 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
614 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
615 |
+
every step.
|
616 |
+
|
617 |
+
Examples:
|
618 |
+
|
619 |
+
```py
|
620 |
+
>>> import PIL
|
621 |
+
>>> import requests
|
622 |
+
>>> import torch
|
623 |
+
>>> from io import BytesIO
|
624 |
+
|
625 |
+
>>> from diffusers import StableDiffusionInstructPix2PixPipeline
|
626 |
+
|
627 |
+
|
628 |
+
>>> def download_image(url):
|
629 |
+
... response = requests.get(url)
|
630 |
+
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
631 |
+
|
632 |
+
|
633 |
+
>>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
|
634 |
+
|
635 |
+
>>> image = download_image(img_url).resize((512, 512))
|
636 |
+
|
637 |
+
>>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
638 |
+
... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
|
639 |
+
... )
|
640 |
+
>>> pipe = pipe.to("cuda")
|
641 |
+
|
642 |
+
>>> prompt = "make the mountains snowy"
|
643 |
+
>>> image = pipe(prompt=prompt, image=image).images[0]
|
644 |
+
```
|
645 |
+
|
646 |
+
Returns:
|
647 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
648 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
649 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
650 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
651 |
+
"not-safe-for-work" (nsfw) content.
|
652 |
+
"""
|
653 |
+
# 0. Check inputs
|
654 |
+
self.check_inputs(
|
655 |
+
prompt,
|
656 |
+
callback_steps,
|
657 |
+
negative_prompt,
|
658 |
+
prompt_embeds,
|
659 |
+
negative_prompt_embeds,
|
660 |
+
)
|
661 |
+
|
662 |
+
# 1. Define call parameters
|
663 |
+
if prompt is not None and isinstance(prompt, str):
|
664 |
+
batch_size = 1
|
665 |
+
elif prompt is not None and isinstance(prompt, list):
|
666 |
+
batch_size = len(prompt)
|
667 |
+
else:
|
668 |
+
batch_size = prompt_embeds.shape[0]
|
669 |
+
|
670 |
+
device = self._execution_device
|
671 |
+
do_classifier_free_guidance = (
|
672 |
+
guidance_scale > 1.0 and image_guidance_scale >= 1.0
|
673 |
+
)
|
674 |
+
# check if scheduler is in sigmas space
|
675 |
+
scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
|
676 |
+
|
677 |
+
# 2. Encode input prompt
|
678 |
+
prompt_embeds = self._encode_prompt(
|
679 |
+
prompt,
|
680 |
+
device,
|
681 |
+
num_images_per_prompt,
|
682 |
+
do_classifier_free_guidance,
|
683 |
+
negative_prompt,
|
684 |
+
prompt_embeds=prompt_embeds,
|
685 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
686 |
+
)
|
687 |
+
|
688 |
+
# 3. Preprocess image
|
689 |
+
# Normalize image to [-1,1]
|
690 |
+
preprocessed_photo = self.image_processor.preprocess(photo)
|
691 |
+
|
692 |
+
# 4. set timesteps
|
693 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
694 |
+
timesteps = self.scheduler.timesteps
|
695 |
+
|
696 |
+
# 5. Prepare Image latents
|
697 |
+
image_latents = self.prepare_image_latents(
|
698 |
+
preprocessed_photo,
|
699 |
+
batch_size,
|
700 |
+
num_images_per_prompt,
|
701 |
+
prompt_embeds.dtype,
|
702 |
+
device,
|
703 |
+
do_classifier_free_guidance,
|
704 |
+
generator,
|
705 |
+
)
|
706 |
+
image_latents = image_latents * self.vae.config.scaling_factor
|
707 |
+
|
708 |
+
height, width = image_latents.shape[-2:]
|
709 |
+
height = height * self.vae_scale_factor
|
710 |
+
width = width * self.vae_scale_factor
|
711 |
+
|
712 |
+
# 6. Prepare latent variables
|
713 |
+
num_channels_latents = self.unet.config.out_channels
|
714 |
+
latents = self.prepare_latents(
|
715 |
+
batch_size * num_images_per_prompt,
|
716 |
+
num_channels_latents,
|
717 |
+
height,
|
718 |
+
width,
|
719 |
+
prompt_embeds.dtype,
|
720 |
+
device,
|
721 |
+
generator,
|
722 |
+
latents,
|
723 |
+
)
|
724 |
+
|
725 |
+
# 7. Check that shapes of latents and image match the UNet channels
|
726 |
+
num_channels_image = image_latents.shape[1]
|
727 |
+
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
|
728 |
+
raise ValueError(
|
729 |
+
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
730 |
+
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
731 |
+
f" `num_channels_image`: {num_channels_image} "
|
732 |
+
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
733 |
+
" `pipeline.unet` or your `image` input."
|
734 |
+
)
|
735 |
+
|
736 |
+
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
737 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
738 |
+
|
739 |
+
# 9. Denoising loop
|
740 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
741 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
742 |
+
for i, t in enumerate(timesteps):
|
743 |
+
# Expand the latents if we are doing classifier free guidance.
|
744 |
+
# The latents are expanded 3 times because for pix2pix the guidance\
|
745 |
+
# is applied for both the text and the input image.
|
746 |
+
latent_model_input = (
|
747 |
+
torch.cat([latents] * 3) if do_classifier_free_guidance else latents
|
748 |
+
)
|
749 |
+
|
750 |
+
# concat latents, image_latents in the channel dimension
|
751 |
+
scaled_latent_model_input = self.scheduler.scale_model_input(
|
752 |
+
latent_model_input, t
|
753 |
+
)
|
754 |
+
scaled_latent_model_input = torch.cat(
|
755 |
+
[scaled_latent_model_input, image_latents], dim=1
|
756 |
+
)
|
757 |
+
|
758 |
+
# predict the noise residual
|
759 |
+
noise_pred = self.unet(
|
760 |
+
scaled_latent_model_input,
|
761 |
+
t,
|
762 |
+
encoder_hidden_states=prompt_embeds,
|
763 |
+
return_dict=False,
|
764 |
+
)[0]
|
765 |
+
|
766 |
+
# perform guidance
|
767 |
+
if do_classifier_free_guidance:
|
768 |
+
(
|
769 |
+
noise_pred_text,
|
770 |
+
noise_pred_image,
|
771 |
+
noise_pred_uncond,
|
772 |
+
) = noise_pred.chunk(3)
|
773 |
+
noise_pred = (
|
774 |
+
noise_pred_uncond
|
775 |
+
+ guidance_scale * (noise_pred_text - noise_pred_image)
|
776 |
+
+ image_guidance_scale * (noise_pred_image - noise_pred_uncond)
|
777 |
+
)
|
778 |
+
|
779 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
780 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
781 |
+
noise_pred = rescale_noise_cfg(
|
782 |
+
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
783 |
+
)
|
784 |
+
|
785 |
+
# compute the previous noisy sample x_t -> x_t-1
|
786 |
+
latents = self.scheduler.step(
|
787 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
788 |
+
)[0]
|
789 |
+
|
790 |
+
# call the callback, if provided
|
791 |
+
if i == len(timesteps) - 1 or (
|
792 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
793 |
+
):
|
794 |
+
progress_bar.update()
|
795 |
+
if callback is not None and i % callback_steps == 0:
|
796 |
+
callback(i, t, latents)
|
797 |
+
|
798 |
+
aov_latents = latents / self.vae.config.scaling_factor
|
799 |
+
aov = self.vae.decode(aov_latents, return_dict=False)[0]
|
800 |
+
do_denormalize = [True] * aov.shape[0]
|
801 |
+
aov_name = required_aovs[0]
|
802 |
+
if aov_name == "albedo" or aov_name == "irradiance":
|
803 |
+
do_gamma_correction = True
|
804 |
+
else:
|
805 |
+
do_gamma_correction = False
|
806 |
+
|
807 |
+
if aov_name == "roughness" or aov_name == "metallic":
|
808 |
+
aov = aov[:, 0:1].repeat(1, 3, 1, 1)
|
809 |
+
|
810 |
+
aov = self.image_processor.postprocess(
|
811 |
+
aov,
|
812 |
+
output_type=output_type,
|
813 |
+
do_denormalize=do_denormalize,
|
814 |
+
do_gamma_correction=do_gamma_correction,
|
815 |
+
)
|
816 |
+
aovs = [aov]
|
817 |
+
|
818 |
+
# Offload last model to CPU
|
819 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
820 |
+
self.final_offload_hook.offload()
|
821 |
+
return StableDiffusionAOVPipelineOutput(images=aovs)
|
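For orientation, here is a minimal, untested sketch of how this pipeline could be driven once a compatible checkpoint is available. The checkpoint path, input image path, and prompt below are placeholders rather than values taken from this commit:

```py
# Hypothetical usage of StableDiffusionAOVMatEstPipeline; paths and prompt are placeholders.
import torch
from PIL import Image
from rgb2x.pipeline_rgb2x import StableDiffusionAOVMatEstPipeline

pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
    "path/to/aov-checkpoint",            # placeholder: a checkpoint saved with these components
    torch_dtype=torch.float16,
).to("cuda")

photo = Image.open("photo.png").convert("RGB")   # placeholder input photo
result = pipe(
    prompt="Albedo (diffuse basecolor)",         # assumed prompt wording
    photo=photo,
    required_aovs=["albedo"],
    num_inference_steps=50,
    guidance_scale=0.0,                          # matches the __call__ defaults above
    image_guidance_scale=0.0,
)
albedo = result.images[0][0]   # postprocess() returns a list of PIL images per requested AOV
albedo.save("albedo.png")
```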
run.sh
ADDED
@@ -0,0 +1,14 @@
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
CONDA_ENV=$(head -1 /code/environment.yml | cut -d" " -f2)
|
4 |
+
eval "$(conda shell.bash hook)"
|
5 |
+
conda activate gradio-user
|
6 |
+
export OMP_NUM_THREADS=4 # default is a wrong value: 7500m
|
7 |
+
|
8 |
+
conda install -n gradio-user pytorch3d=0.7.7 -c pytorch3d -c conda-forge
|
9 |
+
conda install -n gradio-user -c conda-forge open-clip-torch pytorch-lightning
|
10 |
+
|
11 |
+
# Start app.py
|
12 |
+
echo "Starting app.py..."
|
13 |
+
python -c "import torch; x=torch.rand(1, device='cuda'); print(x, x.device.type)"
|
14 |
+
python app.py
|
settings.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
DEFAULT_MODEL_ID = os.getenv("DEFAULT_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")
|
6 |
+
|
7 |
+
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "3"))
|
8 |
+
DEFAULT_NUM_IMAGES = min(MAX_NUM_IMAGES, int(os.getenv("DEFAULT_NUM_IMAGES", "1")))
|
9 |
+
MAX_IMAGE_RESOLUTION = int(os.getenv("MAX_IMAGE_RESOLUTION", "2048"))
|
10 |
+
DEFAULT_IMAGE_RESOLUTION = min(MAX_IMAGE_RESOLUTION, int(os.getenv("DEFAULT_IMAGE_RESOLUTION", "1024")))
|
11 |
+
|
12 |
+
ALLOW_CHANGING_BASE_MODEL = os.getenv("SPACE_ID") != "hysts/ControlNet-v1-1"
|
13 |
+
SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
|
14 |
+
|
15 |
+
MAX_SEED = np.iinfo(np.int32).max
|
16 |
+
|
17 |
+
# Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
|
18 |
+
|
19 |
+
# setup CUDA
|
20 |
+
# disable the following when deploying to Hugging Face
|
21 |
+
# if os.getenv("CUDA_VISIBLE_DEVICES") is None:
|
22 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
|
23 |
+
# os.environ["GRADIO_SERVER_PORT"] = "7864"
|
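A small illustrative example (not part of this commit) of how these environment-driven settings resolve; the values are arbitrary:

```py
# Hypothetical: override the resolution caps before importing settings.
import os
os.environ["MAX_IMAGE_RESOLUTION"] = "1536"
os.environ["DEFAULT_IMAGE_RESOLUTION"] = "2048"   # will be clamped by the min() below

import settings
print(settings.MAX_IMAGE_RESOLUTION)      # 1536
print(settings.DEFAULT_IMAGE_RESOLUTION)  # 1536, i.e. min(MAX_IMAGE_RESOLUTION, requested)
```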
text2tex/lib/__init__.py
ADDED
File without changes
|
text2tex/lib/camera_helper.py
ADDED
@@ -0,0 +1,231 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
6 |
+
|
7 |
+
from pytorch3d.renderer import (
|
8 |
+
PerspectiveCameras,
|
9 |
+
look_at_view_transform
|
10 |
+
)
|
11 |
+
|
12 |
+
# customized
|
13 |
+
import sys
|
14 |
+
sys.path.append(".")
|
15 |
+
|
16 |
+
from lib.constants import VIEWPOINTS
|
17 |
+
|
18 |
+
# ---------------- UTILS ----------------------
|
19 |
+
|
20 |
+
def degree_to_radian(d):
|
21 |
+
return d * np.pi / 180
|
22 |
+
|
23 |
+
def radian_to_degree(r):
|
24 |
+
return 180 * r / np.pi
|
25 |
+
|
26 |
+
def xyz_to_polar(xyz):
|
27 |
+
""" assume y-axis is the up axis """
|
28 |
+
|
29 |
+
x, y, z = xyz
|
30 |
+
|
31 |
+
theta = 180 * np.arccos(z) / np.pi
|
32 |
+
phi = 180 * np.arccos(y) / np.pi
|
33 |
+
|
34 |
+
return theta, phi
|
35 |
+
|
36 |
+
def polar_to_xyz(theta, phi, dist):
|
37 |
+
""" assume y-axis is the up axis """
|
38 |
+
|
39 |
+
theta = degree_to_radian(theta)
|
40 |
+
phi = degree_to_radian(phi)
|
41 |
+
|
42 |
+
x = np.sin(phi) * np.sin(theta) * dist
|
43 |
+
y = np.cos(phi) * dist
|
44 |
+
z = np.sin(phi) * np.cos(theta) * dist
|
45 |
+
|
46 |
+
return [x, y, z]
|
47 |
+
|
48 |
+
|
49 |
+
# ---------------- VIEWPOINTS ----------------------
|
50 |
+
|
51 |
+
|
52 |
+
def filter_viewpoints(pre_viewpoints: dict, viewpoints: dict):
|
53 |
+
""" return the binary mask of viewpoints to be filtered """
|
54 |
+
|
55 |
+
filter_mask = [0 for _ in viewpoints.keys()]
|
56 |
+
for i, v in viewpoints.items():
|
57 |
+
x_v, y_v, z_v = polar_to_xyz(v["azim"], 90 - v["elev"], v["dist"])
|
58 |
+
|
59 |
+
for _, pv in pre_viewpoints.items():
|
60 |
+
x_pv, y_pv, z_pv = polar_to_xyz(pv["azim"], 90 - pv["elev"], pv["dist"])
|
61 |
+
sim = cosine_similarity(
|
62 |
+
np.array([[x_v, y_v, z_v]]),
|
63 |
+
np.array([[x_pv, y_pv, z_pv]])
|
64 |
+
)[0, 0]
|
65 |
+
|
66 |
+
if sim > 0.9:
|
67 |
+
filter_mask[i] = 1
|
68 |
+
|
69 |
+
return filter_mask
|
70 |
+
|
71 |
+
|
72 |
+
def init_viewpoints(mode, sample_space, init_dist, init_elev, principle_directions,
|
73 |
+
use_principle=True, use_shapenet=False, use_objaverse=False):
|
74 |
+
|
75 |
+
if mode == "predefined":
|
76 |
+
|
77 |
+
(
|
78 |
+
dist_list,
|
79 |
+
elev_list,
|
80 |
+
azim_list,
|
81 |
+
sector_list
|
82 |
+
) = init_predefined_viewpoints(sample_space, init_dist, init_elev)
|
83 |
+
|
84 |
+
elif mode == "hemisphere":
|
85 |
+
|
86 |
+
(
|
87 |
+
dist_list,
|
88 |
+
elev_list,
|
89 |
+
azim_list,
|
90 |
+
sector_list
|
91 |
+
) = init_hemisphere_viewpoints(sample_space, init_dist)
|
92 |
+
|
93 |
+
else:
|
94 |
+
raise NotImplementedError()
|
95 |
+
|
96 |
+
# punishments for views -> in case always selecting the same view
|
97 |
+
view_punishments = [1 for _ in range(len(dist_list))]
|
98 |
+
|
99 |
+
if use_principle:
|
100 |
+
|
101 |
+
(
|
102 |
+
dist_list,
|
103 |
+
elev_list,
|
104 |
+
azim_list,
|
105 |
+
sector_list,
|
106 |
+
view_punishments
|
107 |
+
) = init_principle_viewpoints(
|
108 |
+
principle_directions,
|
109 |
+
dist_list,
|
110 |
+
elev_list,
|
111 |
+
azim_list,
|
112 |
+
sector_list,
|
113 |
+
view_punishments,
|
114 |
+
use_shapenet,
|
115 |
+
use_objaverse
|
116 |
+
)
|
117 |
+
|
118 |
+
return dist_list, elev_list, azim_list, sector_list, view_punishments
|
119 |
+
|
120 |
+
|
121 |
+
def init_principle_viewpoints(
|
122 |
+
principle_directions,
|
123 |
+
dist_list,
|
124 |
+
elev_list,
|
125 |
+
azim_list,
|
126 |
+
sector_list,
|
127 |
+
view_punishments,
|
128 |
+
use_shapenet=False,
|
129 |
+
use_objaverse=False
|
130 |
+
):
|
131 |
+
|
132 |
+
if use_shapenet:
|
133 |
+
key = "shapenet"
|
134 |
+
|
135 |
+
pre_elev_list = [v for v in VIEWPOINTS[key]["elev"]]
|
136 |
+
pre_azim_list = [v for v in VIEWPOINTS[key]["azim"]]
|
137 |
+
pre_sector_list = [v for v in VIEWPOINTS[key]["sector"]]
|
138 |
+
|
139 |
+
num_principle = 10
|
140 |
+
pre_dist_list = [dist_list[0] for _ in range(num_principle)]
|
141 |
+
pre_view_punishments = [0 for _ in range(num_principle)]
|
142 |
+
|
143 |
+
elif use_objaverse:
|
144 |
+
key = "objaverse"
|
145 |
+
|
146 |
+
pre_elev_list = [v for v in VIEWPOINTS[key]["elev"]]
|
147 |
+
pre_azim_list = [v for v in VIEWPOINTS[key]["azim"]]
|
148 |
+
pre_sector_list = [v for v in VIEWPOINTS[key]["sector"]]
|
149 |
+
|
150 |
+
num_principle = 10
|
151 |
+
pre_dist_list = [dist_list[0] for _ in range(num_principle)]
|
152 |
+
pre_view_punishments = [0 for _ in range(num_principle)]
|
153 |
+
else:
|
154 |
+
num_principle = 6
|
155 |
+
pre_elev_list = [v for v in VIEWPOINTS[num_principle]["elev"]]
|
156 |
+
pre_azim_list = [v for v in VIEWPOINTS[num_principle]["azim"]]
|
157 |
+
pre_sector_list = [v for v in VIEWPOINTS[num_principle]["sector"]]
|
158 |
+
pre_dist_list = [dist_list[0] for _ in range(num_principle)]
|
159 |
+
pre_view_punishments = [0 for _ in range(num_principle)]
|
160 |
+
|
161 |
+
dist_list = pre_dist_list + dist_list
|
162 |
+
elev_list = pre_elev_list + elev_list
|
163 |
+
azim_list = pre_azim_list + azim_list
|
164 |
+
sector_list = pre_sector_list + sector_list
|
165 |
+
view_punishments = pre_view_punishments + view_punishments
|
166 |
+
|
167 |
+
return dist_list, elev_list, azim_list, sector_list, view_punishments
|
168 |
+
|
169 |
+
|
170 |
+
def init_predefined_viewpoints(sample_space, init_dist, init_elev):
|
171 |
+
|
172 |
+
viewpoints = VIEWPOINTS[sample_space]
|
173 |
+
|
174 |
+
assert sample_space == len(viewpoints["sector"])
|
175 |
+
|
176 |
+
dist_list = [init_dist for _ in range(sample_space)] # always the same dist
|
177 |
+
elev_list = [viewpoints["elev"][i] for i in range(sample_space)]
|
178 |
+
azim_list = [viewpoints["azim"][i] for i in range(sample_space)]
|
179 |
+
sector_list = [viewpoints["sector"][i] for i in range(sample_space)]
|
180 |
+
|
181 |
+
return dist_list, elev_list, azim_list, sector_list
|
182 |
+
|
183 |
+
|
184 |
+
def init_hemisphere_viewpoints(sample_space, init_dist):
|
185 |
+
"""
|
186 |
+
y is up-axis
|
187 |
+
"""
|
188 |
+
|
189 |
+
num_points = 2 * sample_space
|
190 |
+
ga = np.pi * (3. - np.sqrt(5.)) # golden angle in radians
|
191 |
+
|
192 |
+
flags = []
|
193 |
+
elev_list = [] # degree
|
194 |
+
azim_list = [] # degree
|
195 |
+
|
196 |
+
for i in range(num_points):
|
197 |
+
y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1
|
198 |
+
|
199 |
+
# only take the north hemisphere
|
200 |
+
if y >= 0:
|
201 |
+
flags.append(True)
|
202 |
+
else:
|
203 |
+
flags.append(False)
|
204 |
+
|
205 |
+
theta = ga * i # golden angle increment
|
206 |
+
|
207 |
+
elev_list.append(radian_to_degree(np.arcsin(y)))
|
208 |
+
azim_list.append(radian_to_degree(theta))
|
209 |
+
|
210 |
+
radius = np.sqrt(1 - y * y) # radius at y
|
211 |
+
x = np.cos(theta) * radius
|
212 |
+
z = np.sin(theta) * radius
|
213 |
+
|
214 |
+
elev_list = [elev_list[i] for i in range(len(elev_list)) if flags[i]]
|
215 |
+
azim_list = [azim_list[i] for i in range(len(azim_list)) if flags[i]]
|
216 |
+
|
217 |
+
dist_list = [init_dist for _ in elev_list]
|
218 |
+
sector_list = ["good" for _ in elev_list] # HACK don't define sector names for now
|
219 |
+
|
220 |
+
return dist_list, elev_list, azim_list, sector_list
|
221 |
+
|
222 |
+
|
223 |
+
# ---------------- CAMERAS ----------------------
|
224 |
+
|
225 |
+
|
226 |
+
def init_camera(dist, elev, azim, image_size, device):
|
227 |
+
R, T = look_at_view_transform(dist, elev, azim)
|
228 |
+
image_size = torch.tensor([image_size, image_size]).unsqueeze(0)
|
229 |
+
cameras = PerspectiveCameras(R=R, T=T, device=device, image_size=image_size)
|
230 |
+
|
231 |
+
return cameras
|
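An illustrative sketch of how the helpers above fit together, assuming it is run from the text2tex directory (so the `lib.*` imports resolve) and using the 36-view predefined sampling; the numeric values are arbitrary:

```py
# Hypothetical usage of the viewpoint and camera helpers; values are illustrative.
import torch
from lib.camera_helper import init_viewpoints, init_camera

dist_list, elev_list, azim_list, sector_list, view_punishments = init_viewpoints(
    mode="predefined",
    sample_space=36,             # must be a key of VIEWPOINTS (see constants.py below)
    init_dist=1.0,
    init_elev=0,
    principle_directions=None,   # unused on the default (non-ShapeNet/Objaverse) path
    use_principle=True,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cameras = init_camera(dist_list[0], elev_list[0], azim_list[0],
                      image_size=768, device=device)
```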
text2tex/lib/constants.py
ADDED
@@ -0,0 +1,648 @@
1 |
+
PALETTE = {
|
2 |
+
0: [255, 255, 255], # white - background
|
3 |
+
1: [204, 50, 50], # red - old
|
4 |
+
2: [231, 180, 22], # yellow - update
|
5 |
+
3: [45, 201, 55] # green - new
|
6 |
+
}
|
7 |
+
|
8 |
+
QUAD_WEIGHTS = {
|
9 |
+
0: 0, # background
|
10 |
+
1: 0.1, # old
|
11 |
+
2: 0.5, # update
|
12 |
+
3: 1 # new
|
13 |
+
}
|
14 |
+
|
15 |
+
VIEWPOINTS = {
|
16 |
+
1: {
|
17 |
+
"azim": [
|
18 |
+
0
|
19 |
+
],
|
20 |
+
"elev": [
|
21 |
+
0
|
22 |
+
],
|
23 |
+
"sector": [
|
24 |
+
"front"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
2: {
|
28 |
+
"azim": [
|
29 |
+
0,
|
30 |
+
30
|
31 |
+
],
|
32 |
+
"elev": [
|
33 |
+
0,
|
34 |
+
0
|
35 |
+
],
|
36 |
+
"sector": [
|
37 |
+
"front",
|
38 |
+
"front"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
4: {
|
42 |
+
"azim": [
|
43 |
+
45,
|
44 |
+
315,
|
45 |
+
135,
|
46 |
+
225,
|
47 |
+
],
|
48 |
+
"elev": [
|
49 |
+
0,
|
50 |
+
0,
|
51 |
+
0,
|
52 |
+
0,
|
53 |
+
],
|
54 |
+
"sector": [
|
55 |
+
"front right",
|
56 |
+
"front left",
|
57 |
+
"back right",
|
58 |
+
"back left",
|
59 |
+
]
|
60 |
+
},
|
61 |
+
6: {
|
62 |
+
"azim": [
|
63 |
+
0,
|
64 |
+
90,
|
65 |
+
270,
|
66 |
+
0,
|
67 |
+
180,
|
68 |
+
0
|
69 |
+
],
|
70 |
+
"elev": [
|
71 |
+
0,
|
72 |
+
0,
|
73 |
+
0,
|
74 |
+
90,
|
75 |
+
0,
|
76 |
+
-90
|
77 |
+
],
|
78 |
+
"sector": [
|
79 |
+
"front",
|
80 |
+
"right",
|
81 |
+
"left",
|
82 |
+
"top",
|
83 |
+
"back",
|
84 |
+
"bottom",
|
85 |
+
]
|
86 |
+
},
|
87 |
+
"shapenet": {
|
88 |
+
"azim": [
|
89 |
+
270,
|
90 |
+
315,
|
91 |
+
225,
|
92 |
+
0,
|
93 |
+
180,
|
94 |
+
45,
|
95 |
+
135,
|
96 |
+
90,
|
97 |
+
270,
|
98 |
+
270
|
99 |
+
],
|
100 |
+
"elev": [
|
101 |
+
15,
|
102 |
+
15,
|
103 |
+
15,
|
104 |
+
15,
|
105 |
+
15,
|
106 |
+
15,
|
107 |
+
15,
|
108 |
+
15,
|
109 |
+
90,
|
110 |
+
-90
|
111 |
+
],
|
112 |
+
"sector": [
|
113 |
+
"front",
|
114 |
+
"front right",
|
115 |
+
"front left",
|
116 |
+
"right",
|
117 |
+
"left",
|
118 |
+
"back right",
|
119 |
+
"back left",
|
120 |
+
"back",
|
121 |
+
"top",
|
122 |
+
"bottom",
|
123 |
+
]
|
124 |
+
},
|
125 |
+
"objaverse": {
|
126 |
+
"azim": [
|
127 |
+
0,
|
128 |
+
45,
|
129 |
+
315,
|
130 |
+
90,
|
131 |
+
270,
|
132 |
+
135,
|
133 |
+
225,
|
134 |
+
180,
|
135 |
+
0,
|
136 |
+
0
|
137 |
+
],
|
138 |
+
"elev": [
|
139 |
+
15,
|
140 |
+
15,
|
141 |
+
15,
|
142 |
+
15,
|
143 |
+
15,
|
144 |
+
15,
|
145 |
+
15,
|
146 |
+
15,
|
147 |
+
90,
|
148 |
+
-90
|
149 |
+
],
|
150 |
+
"sector": [
|
151 |
+
"front",
|
152 |
+
"front right",
|
153 |
+
"front left",
|
154 |
+
"right",
|
155 |
+
"left",
|
156 |
+
"back right",
|
157 |
+
"back left",
|
158 |
+
"back",
|
159 |
+
"top",
|
160 |
+
"bottom",
|
161 |
+
]
|
162 |
+
},
|
163 |
+
12: {
|
164 |
+
"azim": [
|
165 |
+
45,
|
166 |
+
315,
|
167 |
+
135,
|
168 |
+
225,
|
169 |
+
|
170 |
+
0,
|
171 |
+
45,
|
172 |
+
315,
|
173 |
+
90,
|
174 |
+
270,
|
175 |
+
135,
|
176 |
+
225,
|
177 |
+
180,
|
178 |
+
],
|
179 |
+
"elev": [
|
180 |
+
0,
|
181 |
+
0,
|
182 |
+
0,
|
183 |
+
0,
|
184 |
+
|
185 |
+
45,
|
186 |
+
45,
|
187 |
+
45,
|
188 |
+
45,
|
189 |
+
45,
|
190 |
+
45,
|
191 |
+
45,
|
192 |
+
45,
|
193 |
+
],
|
194 |
+
"sector": [
|
195 |
+
"front right",
|
196 |
+
"front left",
|
197 |
+
"back right",
|
198 |
+
"back left",
|
199 |
+
|
200 |
+
"front",
|
201 |
+
"front right",
|
202 |
+
"front left",
|
203 |
+
"right",
|
204 |
+
"left",
|
205 |
+
"back right",
|
206 |
+
"back left",
|
207 |
+
"back",
|
208 |
+
]
|
209 |
+
},
|
210 |
+
20: {
|
211 |
+
"azim": [
|
212 |
+
45,
|
213 |
+
315,
|
214 |
+
135,
|
215 |
+
225,
|
216 |
+
|
217 |
+
0,
|
218 |
+
45,
|
219 |
+
315,
|
220 |
+
90,
|
221 |
+
270,
|
222 |
+
135,
|
223 |
+
225,
|
224 |
+
180,
|
225 |
+
|
226 |
+
0,
|
227 |
+
45,
|
228 |
+
315,
|
229 |
+
90,
|
230 |
+
270,
|
231 |
+
135,
|
232 |
+
225,
|
233 |
+
180,
|
234 |
+
],
|
235 |
+
"elev": [
|
236 |
+
0,
|
237 |
+
0,
|
238 |
+
0,
|
239 |
+
0,
|
240 |
+
|
241 |
+
30,
|
242 |
+
30,
|
243 |
+
30,
|
244 |
+
30,
|
245 |
+
30,
|
246 |
+
30,
|
247 |
+
30,
|
248 |
+
30,
|
249 |
+
|
250 |
+
60,
|
251 |
+
60,
|
252 |
+
60,
|
253 |
+
60,
|
254 |
+
60,
|
255 |
+
60,
|
256 |
+
60,
|
257 |
+
60,
|
258 |
+
],
|
259 |
+
"sector": [
|
260 |
+
"front right",
|
261 |
+
"front left",
|
262 |
+
"back right",
|
263 |
+
"back left",
|
264 |
+
|
265 |
+
"front",
|
266 |
+
"front right",
|
267 |
+
"front left",
|
268 |
+
"right",
|
269 |
+
"left",
|
270 |
+
"back right",
|
271 |
+
"back left",
|
272 |
+
"back",
|
273 |
+
|
274 |
+
"front",
|
275 |
+
"front right",
|
276 |
+
"front left",
|
277 |
+
"right",
|
278 |
+
"left",
|
279 |
+
"back right",
|
280 |
+
"back left",
|
281 |
+
"back",
|
282 |
+
]
|
283 |
+
},
|
284 |
+
36: {
|
285 |
+
"azim": [
|
286 |
+
45,
|
287 |
+
315,
|
288 |
+
135,
|
289 |
+
225,
|
290 |
+
|
291 |
+
0,
|
292 |
+
45,
|
293 |
+
315,
|
294 |
+
90,
|
295 |
+
270,
|
296 |
+
135,
|
297 |
+
225,
|
298 |
+
180,
|
299 |
+
|
300 |
+
0,
|
301 |
+
45,
|
302 |
+
315,
|
303 |
+
90,
|
304 |
+
270,
|
305 |
+
135,
|
306 |
+
225,
|
307 |
+
180,
|
308 |
+
|
309 |
+
22.5,
|
310 |
+
337.5,
|
311 |
+
67.5,
|
312 |
+
292.5,
|
313 |
+
112.5,
|
314 |
+
247.5,
|
315 |
+
157.5,
|
316 |
+
202.5,
|
317 |
+
|
318 |
+
22.5,
|
319 |
+
337.5,
|
320 |
+
67.5,
|
321 |
+
292.5,
|
322 |
+
112.5,
|
323 |
+
247.5,
|
324 |
+
157.5,
|
325 |
+
202.5,
|
326 |
+
],
|
327 |
+
"elev": [
|
328 |
+
0,
|
329 |
+
0,
|
330 |
+
0,
|
331 |
+
0,
|
332 |
+
|
333 |
+
30,
|
334 |
+
30,
|
335 |
+
30,
|
336 |
+
30,
|
337 |
+
30,
|
338 |
+
30,
|
339 |
+
30,
|
340 |
+
30,
|
341 |
+
|
342 |
+
60,
|
343 |
+
60,
|
344 |
+
60,
|
345 |
+
60,
|
346 |
+
60,
|
347 |
+
60,
|
348 |
+
60,
|
349 |
+
60,
|
350 |
+
|
351 |
+
15,
|
352 |
+
15,
|
353 |
+
15,
|
354 |
+
15,
|
355 |
+
15,
|
356 |
+
15,
|
357 |
+
15,
|
358 |
+
15,
|
359 |
+
|
360 |
+
45,
|
361 |
+
45,
|
362 |
+
45,
|
363 |
+
45,
|
364 |
+
45,
|
365 |
+
45,
|
366 |
+
45,
|
367 |
+
45,
|
368 |
+
],
|
369 |
+
"sector": [
|
370 |
+
"front right",
|
371 |
+
"front left",
|
372 |
+
"back right",
|
373 |
+
"back left",
|
374 |
+
|
375 |
+
"front",
|
376 |
+
"front right",
|
377 |
+
"front left",
|
378 |
+
"right",
|
379 |
+
"left",
|
380 |
+
"back right",
|
381 |
+
"back left",
|
382 |
+
"back",
|
383 |
+
|
384 |
+
"top front",
|
385 |
+
"top right",
|
386 |
+
"top left",
|
387 |
+
"top right",
|
388 |
+
"top left",
|
389 |
+
"top right",
|
390 |
+
"top left",
|
391 |
+
"top back",
|
392 |
+
|
393 |
+
"front right",
|
394 |
+
"front left",
|
395 |
+
"front right",
|
396 |
+
"front left",
|
397 |
+
"back right",
|
398 |
+
"back left",
|
399 |
+
"back right",
|
400 |
+
"back left",
|
401 |
+
|
402 |
+
"front right",
|
403 |
+
"front left",
|
404 |
+
"front right",
|
405 |
+
"front left",
|
406 |
+
"back right",
|
407 |
+
"back left",
|
408 |
+
"back right",
|
409 |
+
"back left",
|
410 |
+
]
|
411 |
+
},
|
412 |
+
68: {
|
413 |
+
"azim": [
|
414 |
+
45,
|
415 |
+
315,
|
416 |
+
135,
|
417 |
+
225,
|
418 |
+
|
419 |
+
0,
|
420 |
+
45,
|
421 |
+
315,
|
422 |
+
90,
|
423 |
+
270,
|
424 |
+
135,
|
425 |
+
225,
|
426 |
+
180,
|
427 |
+
|
428 |
+
0,
|
429 |
+
45,
|
430 |
+
315,
|
431 |
+
90,
|
432 |
+
270,
|
433 |
+
135,
|
434 |
+
225,
|
435 |
+
180,
|
436 |
+
|
437 |
+
22.5,
|
438 |
+
337.5,
|
439 |
+
67.5,
|
440 |
+
292.5,
|
441 |
+
112.5,
|
442 |
+
247.5,
|
443 |
+
157.5,
|
444 |
+
202.5,
|
445 |
+
|
446 |
+
22.5,
|
447 |
+
337.5,
|
448 |
+
67.5,
|
449 |
+
292.5,
|
450 |
+
112.5,
|
451 |
+
247.5,
|
452 |
+
157.5,
|
453 |
+
202.5,
|
454 |
+
|
455 |
+
0,
|
456 |
+
45,
|
457 |
+
315,
|
458 |
+
90,
|
459 |
+
270,
|
460 |
+
135,
|
461 |
+
225,
|
462 |
+
180,
|
463 |
+
|
464 |
+
0,
|
465 |
+
45,
|
466 |
+
315,
|
467 |
+
90,
|
468 |
+
270,
|
469 |
+
135,
|
470 |
+
225,
|
471 |
+
180,
|
472 |
+
|
473 |
+
22.5,
|
474 |
+
337.5,
|
475 |
+
67.5,
|
476 |
+
292.5,
|
477 |
+
112.5,
|
478 |
+
247.5,
|
479 |
+
157.5,
|
480 |
+
202.5,
|
481 |
+
|
482 |
+
22.5,
|
483 |
+
337.5,
|
484 |
+
67.5,
|
485 |
+
292.5,
|
486 |
+
112.5,
|
487 |
+
247.5,
|
488 |
+
157.5,
|
489 |
+
202.5
|
490 |
+
],
|
491 |
+
"elev": [
|
492 |
+
0,
|
493 |
+
0,
|
494 |
+
0,
|
495 |
+
0,
|
496 |
+
|
497 |
+
30,
|
498 |
+
30,
|
499 |
+
30,
|
500 |
+
30,
|
501 |
+
30,
|
502 |
+
30,
|
503 |
+
30,
|
504 |
+
30,
|
505 |
+
|
506 |
+
60,
|
507 |
+
60,
|
508 |
+
60,
|
509 |
+
60,
|
510 |
+
60,
|
511 |
+
60,
|
512 |
+
60,
|
513 |
+
60,
|
514 |
+
|
515 |
+
15,
|
516 |
+
15,
|
517 |
+
15,
|
518 |
+
15,
|
519 |
+
15,
|
520 |
+
15,
|
521 |
+
15,
|
522 |
+
15,
|
523 |
+
|
524 |
+
45,
|
525 |
+
45,
|
526 |
+
45,
|
527 |
+
45,
|
528 |
+
45,
|
529 |
+
45,
|
530 |
+
45,
|
531 |
+
45,
|
532 |
+
|
533 |
+
-30,
|
534 |
+
-30,
|
535 |
+
-30,
|
536 |
+
-30,
|
537 |
+
-30,
|
538 |
+
-30,
|
539 |
+
-30,
|
540 |
+
-30,
|
541 |
+
|
542 |
+
-60,
|
543 |
+
-60,
|
544 |
+
-60,
|
545 |
+
-60,
|
546 |
+
-60,
|
547 |
+
-60,
|
548 |
+
-60,
|
549 |
+
-60,
|
550 |
+
|
551 |
+
-15,
|
552 |
+
-15,
|
553 |
+
-15,
|
554 |
+
-15,
|
555 |
+
-15,
|
556 |
+
-15,
|
557 |
+
-15,
|
558 |
+
-15,
|
559 |
+
|
560 |
+
-45,
|
561 |
+
-45,
|
562 |
+
-45,
|
563 |
+
-45,
|
564 |
+
-45,
|
565 |
+
-45,
|
566 |
+
-45,
|
567 |
+
-45,
|
568 |
+
],
|
569 |
+
"sector": [
|
570 |
+
"front right",
|
571 |
+
"front left",
|
572 |
+
"back right",
|
573 |
+
"back left",
|
574 |
+
|
575 |
+
"front",
|
576 |
+
"front right",
|
577 |
+
"front left",
|
578 |
+
"right",
|
579 |
+
"left",
|
580 |
+
"back right",
|
581 |
+
"back left",
|
582 |
+
"back",
|
583 |
+
|
584 |
+
"top front",
|
585 |
+
"top right",
|
586 |
+
"top left",
|
587 |
+
"top right",
|
588 |
+
"top left",
|
589 |
+
"top right",
|
590 |
+
"top left",
|
591 |
+
"top back",
|
592 |
+
|
593 |
+
"front right",
|
594 |
+
"front left",
|
595 |
+
"front right",
|
596 |
+
"front left",
|
597 |
+
"back right",
|
598 |
+
"back left",
|
599 |
+
"back right",
|
600 |
+
"back left",
|
601 |
+
|
602 |
+
"front right",
|
603 |
+
"front left",
|
604 |
+
"front right",
|
605 |
+
"front left",
|
606 |
+
"back right",
|
607 |
+
"back left",
|
608 |
+
"back right",
|
609 |
+
"back left",
|
610 |
+
|
611 |
+
"front",
|
612 |
+
"front right",
|
613 |
+
"front left",
|
614 |
+
"right",
|
615 |
+
"left",
|
616 |
+
"back right",
|
617 |
+
"back left",
|
618 |
+
"back",
|
619 |
+
|
620 |
+
"bottom front",
|
621 |
+
"bottom right",
|
622 |
+
"bottom left",
|
623 |
+
"bottom right",
|
624 |
+
"bottom left",
|
625 |
+
"bottom right",
|
626 |
+
"bottom left",
|
627 |
+
"bottom back",
|
628 |
+
|
629 |
+
"bottom front right",
|
630 |
+
"bottom front left",
|
631 |
+
"bottom front right",
|
632 |
+
"bottom front left",
|
633 |
+
"bottom back right",
|
634 |
+
"bottom back left",
|
635 |
+
"bottom back right",
|
636 |
+
"bottom back left",
|
637 |
+
|
638 |
+
"bottom front right",
|
639 |
+
"bottom front left",
|
640 |
+
"bottom front right",
|
641 |
+
"bottom front left",
|
642 |
+
"bottom back right",
|
643 |
+
"bottom back left",
|
644 |
+
"bottom back right",
|
645 |
+
"bottom back left",
|
646 |
+
]
|
647 |
+
}
|
648 |
+
}
|
text2tex/lib/diffusion_helper.py
ADDED
@@ -0,0 +1,189 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision import transforms
|
8 |
+
|
9 |
+
# Stable Diffusion 2
|
10 |
+
from diffusers import (
|
11 |
+
StableDiffusionInpaintPipeline,
|
12 |
+
StableDiffusionPipeline,
|
13 |
+
EulerDiscreteScheduler
|
14 |
+
)
|
15 |
+
|
16 |
+
# customized
|
17 |
+
import sys
|
18 |
+
sys.path.append(".")
|
19 |
+
|
20 |
+
from models.ControlNet.gradio_depth2image import init_model, process
|
21 |
+
|
22 |
+
|
23 |
+
def get_controlnet_depth():
|
24 |
+
print("=> initializing ControlNet Depth...")
|
25 |
+
model, ddim_sampler = init_model()
|
26 |
+
|
27 |
+
return model, ddim_sampler
|
28 |
+
|
29 |
+
|
30 |
+
def get_inpainting(device):
|
31 |
+
print("=> initializing Inpainting...")
|
32 |
+
|
33 |
+
model = StableDiffusionInpaintPipeline.from_pretrained(
|
34 |
+
"stabilityai/stable-diffusion-2-inpainting",
|
35 |
+
torch_dtype=torch.float16,
|
36 |
+
).to(device)
|
37 |
+
|
38 |
+
return model
|
39 |
+
|
40 |
+
def get_text2image(device):
|
41 |
+
print("=> initializing Text2Image...")
|
42 |
+
|
43 |
+
model_id = "stabilityai/stable-diffusion-2"
|
44 |
+
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
|
45 |
+
model = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16).to(device)
|
46 |
+
|
47 |
+
return model
|
48 |
+
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def apply_controlnet_depth(model, ddim_sampler,
|
52 |
+
init_image, prompt, strength, ddim_steps,
|
53 |
+
generate_mask_image, keep_mask_image, depth_map_np,
|
54 |
+
a_prompt, n_prompt, guidance_scale, seed, eta, num_samples,
|
55 |
+
device, blend=0, save_memory=False):
|
56 |
+
"""
|
57 |
+
Use ControlNet Depth (depth-conditioned Stable Diffusion) to repaint an image
|
58 |
+
|
59 |
+
Arguments:
|
60 |
+
args: input arguments
|
61 |
+
model: Stable Diffusion 2 model
|
62 |
+
init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
|
63 |
+
mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
|
64 |
+
depth_map_np: depth map of the input image, torch.FloatTensor of shape (1, H, W)
|
65 |
+
"""
|
66 |
+
|
67 |
+
print("=> generating ControlNet Depth RePaint image...")
|
68 |
+
|
69 |
+
|
70 |
+
# Stable Diffusion 2 receives PIL.Image
|
71 |
+
# NOTE Stable Diffusion 2 returns a PIL.Image object
|
72 |
+
# image and mask_image should be PIL images.
|
73 |
+
# The mask structure is white for inpainting and black for keeping as is
|
74 |
+
diffused_image_np = process(
|
75 |
+
model, ddim_sampler,
|
76 |
+
np.array(init_image), prompt, a_prompt, n_prompt, num_samples,
|
77 |
+
ddim_steps, guidance_scale, seed, eta,
|
78 |
+
strength=strength, detected_map=depth_map_np, unknown_mask=np.array(generate_mask_image), save_memory=save_memory
|
79 |
+
)[0]
|
80 |
+
|
81 |
+
init_image = init_image.convert("RGB")
|
82 |
+
diffused_image = Image.fromarray(diffused_image_np).convert("RGB")
|
83 |
+
|
84 |
+
if blend > 0 and transforms.ToTensor()(keep_mask_image).sum() > 0:
|
85 |
+
print("=> blending the generated region...")
|
86 |
+
kernel_size = 3
|
87 |
+
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
88 |
+
|
89 |
+
keep_image_np = np.array(init_image).astype(np.uint8)
|
90 |
+
keep_image_np_dilate = cv2.dilate(keep_image_np, kernel, iterations=1)
|
91 |
+
|
92 |
+
keep_mask_np = np.array(keep_mask_image).astype(np.uint8)
|
93 |
+
keep_mask_np_dilate = cv2.dilate(keep_mask_np, kernel, iterations=1)
|
94 |
+
|
95 |
+
generate_image_np = np.array(diffused_image).astype(np.uint8)
|
96 |
+
|
97 |
+
overlap_mask_np = np.array(generate_mask_image).astype(np.uint8)
|
98 |
+
overlap_mask_np *= keep_mask_np_dilate
|
99 |
+
print("=> blending {} pixels...".format(np.sum(overlap_mask_np)))
|
100 |
+
|
101 |
+
overlap_keep = keep_image_np_dilate[overlap_mask_np == 1]
|
102 |
+
overlap_generate = generate_image_np[overlap_mask_np == 1]
|
103 |
+
|
104 |
+
overlap_np = overlap_keep * blend + overlap_generate * (1 - blend)
|
105 |
+
|
106 |
+
generate_image_np[overlap_mask_np == 1] = overlap_np
|
107 |
+
|
108 |
+
diffused_image = Image.fromarray(generate_image_np.astype(np.uint8)).convert("RGB")
|
109 |
+
|
110 |
+
init_image_masked = init_image
|
111 |
+
diffused_image_masked = diffused_image
|
112 |
+
|
113 |
+
return diffused_image, init_image_masked, diffused_image_masked
|
114 |
+
|
115 |
+
|
116 |
+
@torch.no_grad()
|
117 |
+
def apply_inpainting(model,
|
118 |
+
init_image, mask_image_tensor, prompt, height, width, device):
|
119 |
+
"""
|
120 |
+
Use Stable Diffusion 2 to generate image
|
121 |
+
|
122 |
+
Arguments:
|
123 |
+
args: input arguments
|
124 |
+
model: Stable Diffusion 2 model
|
125 |
+
init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
|
126 |
+
mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
|
127 |
+
depth_map_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W)
|
128 |
+
"""
|
129 |
+
|
130 |
+
print("=> generating Inpainting image...")
|
131 |
+
|
132 |
+
mask_image = mask_image_tensor[0].cpu()
|
133 |
+
mask_image = mask_image.permute(2, 0, 1)
|
134 |
+
mask_image = transforms.ToPILImage()(mask_image).convert("L")
|
135 |
+
|
136 |
+
# NOTE Stable Diffusion 2 returns a PIL.Image object
|
137 |
+
# image and mask_image should be PIL images.
|
138 |
+
# The mask structure is white for inpainting and black for keeping as is
|
139 |
+
diffused_image = model(
|
140 |
+
prompt=prompt,
|
141 |
+
image=init_image.resize((512, 512)),
|
142 |
+
mask_image=mask_image.resize((512, 512)),
|
143 |
+
height=512,
|
144 |
+
width=512
|
145 |
+
).images[0].resize((height, width))
|
146 |
+
|
147 |
+
return diffused_image
|
148 |
+
|
149 |
+
|
150 |
+
@torch.no_grad()
|
151 |
+
def apply_inpainting_postprocess(model,
|
152 |
+
init_image, mask_image_tensor, prompt, height, width, device):
|
153 |
+
"""
|
154 |
+
Use Stable Diffusion 2 to generate image
|
155 |
+
|
156 |
+
Arguments:
|
157 |
+
args: input arguments
|
158 |
+
model: Stable Diffusion 2 model
|
159 |
+
init_image_tensor: input image, torch.FloatTensor of shape (1, H, W, 3)
|
160 |
+
mask_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W, 1)
|
161 |
+
depth_map_tensor: depth map of the input image, torch.FloatTensor of shape (1, H, W)
|
162 |
+
"""
|
163 |
+
|
164 |
+
print("=> generating Inpainting image...")
|
165 |
+
|
166 |
+
mask_image = mask_image_tensor[0].cpu()
|
167 |
+
mask_image = mask_image.permute(2, 0, 1)
|
168 |
+
mask_image = transforms.ToPILImage()(mask_image).convert("L")
|
169 |
+
|
170 |
+
# NOTE Stable Diffusion 2 returns a PIL.Image object
|
171 |
+
# image and mask_image should be PIL images.
|
172 |
+
# The mask structure is white for inpainting and black for keeping as is
|
173 |
+
diffused_image = model(
|
174 |
+
prompt=prompt,
|
175 |
+
image=init_image.resize((512, 512)),
|
176 |
+
mask_image=mask_image.resize((512, 512)),
|
177 |
+
height=512,
|
178 |
+
width=512
|
179 |
+
).images[0].resize((height, width))
|
180 |
+
|
181 |
+
diffused_image_tensor = torch.from_numpy(np.array(diffused_image)).to(device)
|
182 |
+
|
183 |
+
init_images_tensor = torch.from_numpy(np.array(init_image)).to(device)
|
184 |
+
|
185 |
+
init_images_tensor = diffused_image_tensor * mask_image_tensor[0] + init_images_tensor * (1 - mask_image_tensor[0])
|
186 |
+
init_image = Image.fromarray(init_images_tensor.cpu().numpy().astype(np.uint8)).convert("RGB")
|
187 |
+
|
188 |
+
return init_image
|
189 |
+
|
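Side note (not part of the committed file): the comments in the inpainting helpers above rely on the diffusers convention that white mask pixels are regenerated and black pixels are kept. A minimal standalone sketch of building such a mask, assuming only NumPy and PIL:

import numpy as np
from PIL import Image

# Boolean "regenerate here" mask -> PIL "L" mask expected by the inpainting pipeline:
# 255 (white) marks pixels to inpaint, 0 (black) marks pixels to keep.
generate_region = np.zeros((512, 512), dtype=bool)
generate_region[128:384, 128:384] = True          # e.g. repaint the centre patch

mask_image = Image.fromarray((generate_region * 255).astype(np.uint8), mode="L")
mask_image.save("inpaint_mask.png")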
text2tex/lib/io_helper.py
ADDED
@@ -0,0 +1,78 @@
# common utils
import os
import json

# numpy
import numpy as np

# visualization
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt

matplotlib.use("Agg")

from pytorch3d.io import save_obj

from torchvision import transforms


def save_depth(fragments, output_dir, init_image, view_idx):
    print("=> saving depth...")
    width, height = init_image.size
    dpi = 100
    figsize = width / float(dpi), height / float(dpi)

    depth_np = fragments.zbuf[0].cpu().numpy()

    fig = plt.figure(figsize=figsize)
    ax = fig.add_axes([0, 0, 1, 1])
    # Hide spines, ticks, etc.
    ax.axis('off')
    # Display the image.
    ax.imshow(depth_np, cmap='gray')

    plt.savefig(os.path.join(output_dir, "{}.png".format(view_idx)), bbox_inches='tight', pad_inches=0)
    np.save(os.path.join(output_dir, "{}.npy".format(view_idx)), depth_np[..., 0])


def save_backproject_obj(output_dir, obj_name,
    verts, faces, verts_uvs, faces_uvs, projected_texture,
    device):
    print("=> saving OBJ file...")
    texture_map = transforms.ToTensor()(projected_texture).to(device)
    texture_map = texture_map.permute(1, 2, 0)
    obj_path = os.path.join(output_dir, obj_name)

    save_obj(
        obj_path,
        verts=verts,
        faces=faces,
        decimal_places=5,
        verts_uvs=verts_uvs,
        faces_uvs=faces_uvs,
        texture_map=texture_map
    )


def save_args(args, output_dir):
    with open(os.path.join(output_dir, "args.json"), "w") as f:
        json.dump(
            {k: v for k, v in vars(args).items()},
            f,
            indent=4
        )


def save_viewpoints(args, output_dir, dist_list, elev_list, azim_list, view_list):
    with open(os.path.join(output_dir, "viewpoints.json"), "w") as f:
        json.dump(
            {
                "dist": dist_list,
                "elev": elev_list,
                "azim": azim_list,
                "view": view_list
            },
            f,
            indent=4
        )
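Side note (not part of the committed file): save_args simply serializes the argparse namespace as JSON. A minimal standalone sketch of the same idea, with illustrative arguments and an assumed output path:

import argparse
import json
import os

parser = argparse.ArgumentParser()
parser.add_argument("--prompt", default="a brick house")      # illustrative args
parser.add_argument("--uv_size", type=int, default=1024)
args = parser.parse_args([])

output_dir = "output/run_0"                                    # assumed path
os.makedirs(output_dir, exist_ok=True)

# Same pattern as save_args: dump the namespace dict with indentation.
with open(os.path.join(output_dir, "args.json"), "w") as f:
    json.dump(vars(args), f, indent=4)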
text2tex/lib/mesh_helper.py
ADDED
@@ -0,0 +1,148 @@
import os
import torch
import trimesh
import xatlas

import numpy as np

from sklearn.decomposition import PCA

from torchvision import transforms

from tqdm import tqdm

from pytorch3d.io import (
    load_obj,
    load_objs_as_meshes
)


def compute_principle_directions(model_path, num_points=20000):
    mesh = trimesh.load_mesh(model_path, force="mesh")
    pc, _ = trimesh.sample.sample_surface_even(mesh, num_points)

    pc -= np.mean(pc, axis=0, keepdims=True)

    principle_directions = PCA(n_components=3).fit(pc).components_

    return principle_directions


def init_mesh(input_path, cache_path, device):
    print("=> parameterizing target mesh...")

    mesh = trimesh.load_mesh(input_path, force='mesh')
    try:
        vertices, faces = mesh.vertices, mesh.faces
    except AttributeError:
        print("multiple materials in {} are not supported".format(input_path))
        exit()

    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
    xatlas.export(str(cache_path), vertices[vmapping], indices, uvs)

    print("=> loading target mesh...")

    # principle_directions = compute_principle_directions(cache_path)
    principle_directions = None

    _, faces, aux = load_obj(cache_path, device=device)
    mesh = load_objs_as_meshes([cache_path], device=device)

    num_verts = mesh.verts_packed().shape[0]

    # make sure mesh center is at origin
    bbox = mesh.get_bounding_boxes()
    mesh_center = bbox.mean(dim=2).repeat(num_verts, 1)
    mesh = apply_offsets_to_mesh(mesh, -mesh_center)

    # make sure mesh size is normalized
    box_size = bbox[..., 1] - bbox[..., 0]
    box_max = box_size.max(dim=1, keepdim=True)[0].repeat(num_verts, 3)
    mesh = apply_scale_to_mesh(mesh, 1 / box_max)

    return mesh, mesh.verts_packed(), faces, aux, principle_directions, mesh_center, box_max


def apply_offsets_to_mesh(mesh, offsets):
    new_mesh = mesh.offset_verts(offsets)

    return new_mesh

def apply_scale_to_mesh(mesh, scale):
    new_mesh = mesh.scale_verts(scale)

    return new_mesh


def adjust_uv_map(faces, aux, init_texture, uv_size):
    """
        adjust UV map to be compatiable with multiple textures.
        UVs for different materials will be decomposed and placed horizontally

            +-----+-----+-----+--
            |  1  |  2  |  3  |
            +-----+-----+-----+--

    """

    textures_ids = faces.textures_idx
    materials_idx = faces.materials_idx
    verts_uvs = aux.verts_uvs

    num_materials = torch.unique(materials_idx).shape[0]

    new_verts_uvs = verts_uvs.clone()
    for material_id in range(num_materials):
        # apply offsets to horizontal axis
        faces_ids = textures_ids[materials_idx == material_id].unique()
        new_verts_uvs[faces_ids, 0] += material_id

    new_verts_uvs[:, 0] /= num_materials

    init_texture_tensor = transforms.ToTensor()(init_texture)
    init_texture_tensor = torch.cat([init_texture_tensor for _ in range(num_materials)], dim=-1)
    init_texture = transforms.ToPILImage()(init_texture_tensor).resize((uv_size, uv_size))

    return new_verts_uvs, init_texture


@torch.no_grad()
def update_face_angles(mesh, cameras, fragments):
    def get_angle(x, y):
        x = torch.nn.functional.normalize(x)
        y = torch.nn.functional.normalize(y)
        inner_product = (x * y).sum(dim=1)
        x_norm = x.pow(2).sum(dim=1).pow(0.5)
        y_norm = y.pow(2).sum(dim=1).pow(0.5)
        cos = inner_product / (x_norm * y_norm)
        angle = torch.acos(cos)
        angle = angle * 180 / 3.14159

        return angle

    # face normals
    face_normals = mesh.faces_normals_padded()[0]

    # view vector (object center -> camera center)
    camera_center = cameras.get_camera_center()

    face_angles = get_angle(
        face_normals,
        camera_center.repeat(face_normals.shape[0], 1)
    ) # (F)

    face_angles_rev = get_angle(
        face_normals,
        -camera_center.repeat(face_normals.shape[0], 1)
    ) # (F)

    face_angles = torch.minimum(face_angles, face_angles_rev)

    # Indices of unique visible faces
    visible_map = fragments.pix_to_face.unique() # (num_visible_faces)
    invisible_mask = torch.ones_like(face_angles)
    invisible_mask[visible_map] = 0
    face_angles[invisible_mask == 1] = 10000. # angles of invisible faces are ignored

    return face_angles
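Side note (not part of the committed file): init_mesh recenters the mesh at the origin and rescales it by the longest bounding-box edge. A minimal NumPy sketch of that normalization, using toy vertices:

import numpy as np

# Shift the vertices so the bounding-box centre sits at the origin,
# then divide by the longest bounding-box edge (same idea as init_mesh).
verts = np.random.rand(100, 3) * 5.0 + 2.0         # toy vertex positions

bbox_min, bbox_max = verts.min(axis=0), verts.max(axis=0)
center = (bbox_min + bbox_max) / 2.0
scale = (bbox_max - bbox_min).max()

verts_normalized = (verts - center) / scale         # now fits inside a unit cube
assert (np.abs(verts_normalized) <= 0.5 + 1e-6).all()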
text2tex/lib/projection_helper.py
ADDED
@@ -0,0 +1,464 @@
import os
import torch

import cv2
import random

import numpy as np

from torchvision import transforms

from pytorch3d.renderer import TexturesUV
from pytorch3d.ops import interpolate_face_attributes

from PIL import Image

from tqdm import tqdm

# customized
import sys
sys.path.append(".")

from lib.camera_helper import init_camera
from lib.render_helper import init_renderer, render
from lib.shading_helper import (
    BlendParams,
    init_soft_phong_shader,
    init_flat_texel_shader,
)
from lib.vis_helper import visualize_outputs, visualize_quad_mask
from lib.constants import *


def get_all_4_locations(values_y, values_x):
    y_0 = torch.floor(values_y)
    y_1 = torch.ceil(values_y)
    x_0 = torch.floor(values_x)
    x_1 = torch.ceil(values_x)

    return torch.cat([y_0, y_0, y_1, y_1], 0).long(), torch.cat([x_0, x_1, x_0, x_1], 0).long()


def compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device):
    """
        compose quad mask:
            -> 0: background
            -> 1: old
            -> 2: update
            -> 3: new
    """

    new_mask_tensor = transforms.ToTensor()(new_mask_image).to(device)
    update_mask_tensor = transforms.ToTensor()(update_mask_image).to(device)
    old_mask_tensor = transforms.ToTensor()(old_mask_image).to(device)

    all_mask_tensor = new_mask_tensor + update_mask_tensor + old_mask_tensor

    quad_mask_tensor = torch.zeros_like(all_mask_tensor)
    quad_mask_tensor[old_mask_tensor == 1] = 1
    quad_mask_tensor[update_mask_tensor == 1] = 2
    quad_mask_tensor[new_mask_tensor == 1] = 3

    return old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor


def compute_view_heat(similarity_tensor, quad_mask_tensor):
    num_total_pixels = quad_mask_tensor.reshape(-1).shape[0]
    heat = 0
    for idx in QUAD_WEIGHTS:
        heat += (quad_mask_tensor == idx).sum() * QUAD_WEIGHTS[idx] / num_total_pixels

    return heat


def select_viewpoint(selected_view_ids, view_punishments,
    mode, dist_list, elev_list, azim_list, sector_list, view_idx,
    similarity_texture_cache, exist_texture,
    mesh, faces, verts_uvs,
    image_size, faces_per_pixel,
    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
    device, use_principle=False
):
    if mode == "sequential":

        num_views = len(dist_list)

        dist = dist_list[view_idx % num_views]
        elev = elev_list[view_idx % num_views]
        azim = azim_list[view_idx % num_views]
        sector = sector_list[view_idx % num_views]

        selected_view_ids.append(view_idx % num_views)

    elif mode == "heuristic":

        if use_principle and view_idx < 6:

            selected_view_idx = view_idx

        else:

            selected_view_idx = None
            max_heat = 0

            print("=> selecting next view...")
            view_heat_list = []
            for sample_idx in tqdm(range(len(dist_list))):

                view_heat, *_ = render_one_view_and_build_masks(dist_list[sample_idx], elev_list[sample_idx], azim_list[sample_idx],
                    sample_idx, sample_idx, view_punishments,
                    similarity_texture_cache, exist_texture,
                    mesh, faces, verts_uvs,
                    image_size, faces_per_pixel,
                    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
                    device)

                if view_heat > max_heat:
                    selected_view_idx = sample_idx
                    max_heat = view_heat

                view_heat_list.append(view_heat.item())

            print(view_heat_list)
            print("select view {} with heat {}".format(selected_view_idx, max_heat))


        dist = dist_list[selected_view_idx]
        elev = elev_list[selected_view_idx]
        azim = azim_list[selected_view_idx]
        sector = sector_list[selected_view_idx]

        selected_view_ids.append(selected_view_idx)

        view_punishments[selected_view_idx] *= 0.01

    elif mode == "random":

        selected_view_idx = random.choice(range(len(dist_list)))

        dist = dist_list[selected_view_idx]
        elev = elev_list[selected_view_idx]
        azim = azim_list[selected_view_idx]
        sector = sector_list[selected_view_idx]

        selected_view_ids.append(selected_view_idx)

    else:
        raise NotImplementedError()

    return dist, elev, azim, sector, selected_view_ids, view_punishments


@torch.no_grad()
def build_backproject_mask(mesh, faces, verts_uvs,
    cameras, reference_image, faces_per_pixel,
    image_size, uv_size, device):
    # construct pixel UVs
    renderer_scaled = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    fragments_scaled = renderer_scaled.rasterizer(mesh)

    # get UV coordinates for each pixel
    faces_verts_uvs = verts_uvs[faces.textures_idx]

    pixel_uvs = interpolate_face_attributes(
        fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
    )  # NxHsxWsxKx2
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(-1, 2)

    texture_locations_y, texture_locations_x = get_all_4_locations(
        (1 - pixel_uvs[:, 1]).reshape(-1) * (uv_size - 1),
        pixel_uvs[:, 0].reshape(-1) * (uv_size - 1)
    )

    K = faces_per_pixel

    texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size)))).float() / 255.
    texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])

    # texture
    texture_tensor = torch.zeros(uv_size, uv_size, 3).to(device)
    texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values.reshape(-1, 3)

    return texture_tensor[:, :, 0]


@torch.no_grad()
def build_diffusion_mask(mesh_stuff,
    renderer, exist_texture, similarity_texture_cache, target_value, device, image_size,
    smooth_mask=False, view_threshold=0.01):

    mesh, faces, verts_uvs = mesh_stuff
    mask_mesh = mesh.clone()  # NOTE in-place operation - DANGER!!!

    # visible mask => the whole region
    exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
    mask_mesh.textures = TexturesUV(
        maps=torch.ones_like(exist_texture_expand),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    # visible_mask_tensor, *_ = render(mask_mesh, renderer)
    visible_mask_tensor, _, similarity_map_tensor, *_ = render(mask_mesh, renderer)
    # faces that are too rotated away from the viewpoint will be treated as invisible
    valid_mask_tensor = (similarity_map_tensor >= view_threshold).float()
    visible_mask_tensor *= valid_mask_tensor

    # nonexist mask <=> new mask
    exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
    mask_mesh.textures = TexturesUV(
        maps=1 - exist_texture_expand,
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    new_mask_tensor, *_ = render(mask_mesh, renderer)
    new_mask_tensor *= valid_mask_tensor

    # exist mask => visible mask - new mask
    exist_mask_tensor = visible_mask_tensor - new_mask_tensor
    exist_mask_tensor[exist_mask_tensor < 0] = 0  # NOTE dilate can lead to overflow

    # all update mask
    mask_mesh.textures = TexturesUV(
        maps=(
            similarity_texture_cache.argmax(0) == target_value
            # # only consider the views that have already appeared before
            # similarity_texture_cache[0:target_value+1].argmax(0) == target_value
        ).float().unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    all_update_mask_tensor, *_ = render(mask_mesh, renderer)

    # current update mask => intersection between all update mask and exist mask
    update_mask_tensor = exist_mask_tensor * all_update_mask_tensor

    # keep mask => exist mask - update mask
    old_mask_tensor = exist_mask_tensor - update_mask_tensor

    # convert
    new_mask = new_mask_tensor[0].cpu().float().permute(2, 0, 1)
    new_mask = transforms.ToPILImage()(new_mask).convert("L")

    update_mask = update_mask_tensor[0].cpu().float().permute(2, 0, 1)
    update_mask = transforms.ToPILImage()(update_mask).convert("L")

    old_mask = old_mask_tensor[0].cpu().float().permute(2, 0, 1)
    old_mask = transforms.ToPILImage()(old_mask).convert("L")

    exist_mask = exist_mask_tensor[0].cpu().float().permute(2, 0, 1)
    exist_mask = transforms.ToPILImage()(exist_mask).convert("L")

    return new_mask, update_mask, old_mask, exist_mask


@torch.no_grad()
def render_one_view(mesh,
    dist, elev, azim,
    image_size, faces_per_pixel,
    device):

    # render the view
    cameras = init_camera(
        dist, elev, azim,
        image_size, device
    )
    renderer = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )

    init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments = render(mesh, renderer)

    return (
        cameras, renderer,
        init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
    )


@torch.no_grad()
def build_similarity_texture_cache_for_all_views(mesh, faces, verts_uvs,
    dist_list, elev_list, azim_list,
    image_size, image_size_scaled, uv_size, faces_per_pixel,
    device):

    num_candidate_views = len(dist_list)
    similarity_texture_cache = torch.zeros(num_candidate_views, uv_size, uv_size).to(device)

    print("=> building similarity texture cache for all views...")
    for i in tqdm(range(num_candidate_views)):
        cameras, _, _, _, similarity_tensor, _, _ = render_one_view(mesh,
            dist_list[i], elev_list[i], azim_list[i],
            image_size, faces_per_pixel, device)

        similarity_texture_cache[i] = build_backproject_mask(mesh, faces, verts_uvs,
            cameras, transforms.ToPILImage()(similarity_tensor[0, :, :, 0]).convert("RGB"), faces_per_pixel,
            image_size_scaled, uv_size, device)

    return similarity_texture_cache


@torch.no_grad()
def render_one_view_and_build_masks(dist, elev, azim,
    selected_view_idx, view_idx, view_punishments,
    similarity_texture_cache, exist_texture,
    mesh, faces, verts_uvs,
    image_size, faces_per_pixel,
    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
    device, save_intermediate=False, smooth_mask=False, view_threshold=0.01):

    # render the view
    (
        cameras, renderer,
        init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
    ) = render_one_view(mesh,
        dist, elev, azim,
        image_size, faces_per_pixel,
        device
    )

    init_image = init_images_tensor[0].cpu()
    init_image = init_image.permute(2, 0, 1)
    init_image = transforms.ToPILImage()(init_image).convert("RGB")

    normal_map = normal_maps_tensor[0].cpu()
    normal_map = normal_map.permute(2, 0, 1)
    normal_map = transforms.ToPILImage()(normal_map).convert("RGB")

    depth_map = depth_maps_tensor[0].cpu().numpy()
    depth_map = Image.fromarray(depth_map).convert("L")

    similarity_map = similarity_tensor[0, :, :, 0].cpu()
    similarity_map = transforms.ToPILImage()(similarity_map).convert("L")


    flat_renderer = init_renderer(cameras,
        shader=init_flat_texel_shader(
            camera=cameras,
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    new_mask_image, update_mask_image, old_mask_image, exist_mask_image = build_diffusion_mask(
        (mesh, faces, verts_uvs),
        flat_renderer, exist_texture, similarity_texture_cache, selected_view_idx, device, image_size,
        smooth_mask=smooth_mask, view_threshold=view_threshold
    )
    # NOTE the view idx is the absolute idx in the sample space (i.e. `selected_view_idx`)
    # it should match with `similarity_texture_cache`

    (
        old_mask_tensor,
        update_mask_tensor,
        new_mask_tensor,
        all_mask_tensor,
        quad_mask_tensor
    ) = compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device)

    view_heat = compute_view_heat(similarity_tensor, quad_mask_tensor)
    view_heat *= view_punishments[selected_view_idx]

    # save intermediate results
    if save_intermediate:
        init_image.save(os.path.join(init_image_dir, "{}.png".format(view_idx)))
        normal_map.save(os.path.join(normal_map_dir, "{}.png".format(view_idx)))
        depth_map.save(os.path.join(depth_map_dir, "{}.png".format(view_idx)))
        similarity_map.save(os.path.join(similarity_map_dir, "{}.png".format(view_idx)))

        new_mask_image.save(os.path.join(mask_image_dir, "{}_new.png".format(view_idx)))
        update_mask_image.save(os.path.join(mask_image_dir, "{}_update.png".format(view_idx)))
        old_mask_image.save(os.path.join(mask_image_dir, "{}_old.png".format(view_idx)))
        exist_mask_image.save(os.path.join(mask_image_dir, "{}_exist.png".format(view_idx)))

        visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx, view_heat, device)

    return (
        view_heat,
        renderer, cameras, fragments,
        init_image, normal_map, depth_map,
        init_images_tensor, normal_maps_tensor, depth_maps_tensor, similarity_tensor,
        old_mask_image, update_mask_image, new_mask_image,
        old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
    )



@torch.no_grad()
def backproject_from_image(mesh, faces, verts_uvs, cameras,
    reference_image, new_mask_image, update_mask_image,
    init_texture, exist_texture,
    image_size, uv_size, faces_per_pixel,
    device):

    # construct pixel UVs
    renderer_scaled = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    fragments_scaled = renderer_scaled.rasterizer(mesh)

    # get UV coordinates for each pixel
    faces_verts_uvs = verts_uvs[faces.textures_idx]

    pixel_uvs = interpolate_face_attributes(
        fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
    )  # NxHsxWsxKx2
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(pixel_uvs.shape[-2], pixel_uvs.shape[1], pixel_uvs.shape[2], 2)

    # the update mask has to be on top of the diffusion mask
    new_mask_image_tensor = transforms.ToTensor()(new_mask_image).to(device).unsqueeze(-1)
    update_mask_image_tensor = transforms.ToTensor()(update_mask_image).to(device).unsqueeze(-1)

    project_mask_image_tensor = torch.logical_or(update_mask_image_tensor, new_mask_image_tensor).float()
    project_mask_image = project_mask_image_tensor * 255.
    project_mask_image = Image.fromarray(project_mask_image[0, :, :, 0].cpu().numpy().astype(np.uint8))

    project_mask_image_scaled = project_mask_image.resize(
        (image_size, image_size),
        Image.Resampling.NEAREST
    )
    project_mask_image_tensor_scaled = transforms.ToTensor()(project_mask_image_scaled).to(device)

    pixel_uvs_masked = pixel_uvs[project_mask_image_tensor_scaled == 1]

    texture_locations_y, texture_locations_x = get_all_4_locations(
        (1 - pixel_uvs_masked[:, 1]).reshape(-1) * (uv_size - 1),
        pixel_uvs_masked[:, 0].reshape(-1) * (uv_size - 1)
    )

    K = pixel_uvs.shape[0]
    project_mask_image_tensor_scaled = project_mask_image_tensor_scaled[:, None, :, :, None].repeat(1, 4, 1, 1, 3)

    texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size))))
    texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])

    texture_values_masked = texture_values.reshape(-1, 3)[project_mask_image_tensor_scaled.reshape(-1, 3) == 1].reshape(-1, 3)

    # texture
    texture_tensor = torch.from_numpy(np.array(init_texture)).to(device)
    texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values_masked

    init_texture = Image.fromarray(texture_tensor.cpu().numpy().astype(np.uint8))

    # update texture cache
    exist_texture[texture_locations_y, texture_locations_x] = 1

    return init_texture, project_mask_image, exist_texture
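Side note (not part of the committed file): get_all_4_locations splats each continuous UV sample onto the four surrounding integer texel locations, which is how the back-projection fills the texture without interpolation gaps. A minimal standalone sketch with toy coordinates:

import torch

uv_size = 8
values_y = torch.tensor([2.3, 5.0])                 # toy continuous texel rows
values_x = torch.tensor([4.7, 1.5])                 # toy continuous texel columns

# Same pairing as get_all_4_locations: (y0,x0), (y0,x1), (y1,x0), (y1,x1).
ys = torch.cat([values_y.floor(), values_y.floor(), values_y.ceil(), values_y.ceil()]).long()
xs = torch.cat([values_x.floor(), values_x.ceil(), values_x.floor(), values_x.ceil()]).long()

texture = torch.zeros(uv_size, uv_size)
texture[ys, xs] = 1.0                               # mark every texel touched by the samples
print(texture)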
text2tex/lib/render_helper.py
ADDED
@@ -0,0 +1,108 @@
import os
import torch

import cv2

import numpy as np

from PIL import Image

from torchvision import transforms
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.renderer import (
    RasterizationSettings,
    MeshRendererWithFragments,
    MeshRasterizer,
)

# customized
import sys
sys.path.append(".")


def init_renderer(camera, shader, image_size, faces_per_pixel):
    raster_settings = RasterizationSettings(image_size=image_size, faces_per_pixel=faces_per_pixel)
    renderer = MeshRendererWithFragments(
        rasterizer=MeshRasterizer(
            cameras=camera,
            raster_settings=raster_settings
        ),
        shader=shader
    )

    return renderer


@torch.no_grad()
def render(mesh, renderer, pad_value=10):
    def phong_normal_shading(meshes, fragments) -> torch.Tensor:
        faces = meshes.faces_packed()  # (F, 3)
        vertex_normals = meshes.verts_normals_packed()  # (V, 3)
        faces_normals = vertex_normals[faces]
        pixel_normals = interpolate_face_attributes(
            fragments.pix_to_face, fragments.bary_coords, faces_normals
        )

        return pixel_normals

    def similarity_shading(meshes, fragments):
        faces = meshes.faces_packed()  # (F, 3)
        vertex_normals = meshes.verts_normals_packed()  # (V, 3)
        faces_normals = vertex_normals[faces]
        vertices = meshes.verts_packed()  # (V, 3)
        face_positions = vertices[faces]
        view_directions = torch.nn.functional.normalize((renderer.shader.cameras.get_camera_center().reshape(1, 1, 3) - face_positions), p=2, dim=2)
        cosine_similarity = torch.nn.CosineSimilarity(dim=2)(faces_normals, view_directions)
        pixel_similarity = interpolate_face_attributes(
            fragments.pix_to_face, fragments.bary_coords, cosine_similarity.unsqueeze(-1)
        )

        return pixel_similarity

    def get_relative_depth_map(fragments, pad_value=pad_value):
        absolute_depth = fragments.zbuf[..., 0]  # B, H, W
        no_depth = -1

        depth_min, depth_max = absolute_depth[absolute_depth != no_depth].min(), absolute_depth[absolute_depth != no_depth].max()
        target_min, target_max = 50, 255

        depth_value = absolute_depth[absolute_depth != no_depth]
        depth_value = depth_max - depth_value  # reverse values

        depth_value /= (depth_max - depth_min)
        depth_value = depth_value * (target_max - target_min) + target_min

        relative_depth = absolute_depth.clone()
        relative_depth[absolute_depth != no_depth] = depth_value
        relative_depth[absolute_depth == no_depth] = pad_value  # not completely black

        return relative_depth


    images, fragments = renderer(mesh)
    normal_maps = phong_normal_shading(mesh, fragments).squeeze(-2)
    similarity_maps = similarity_shading(mesh, fragments).squeeze(-2)  # -1 - 1
    depth_maps = get_relative_depth_map(fragments)

    # normalize similarity mask to 0 - 1
    similarity_maps = torch.abs(similarity_maps)  # 0 - 1

    # HACK erode, eliminate isolated dots
    non_zero_similarity = (similarity_maps > 0).float()
    non_zero_similarity = (non_zero_similarity * 255.).cpu().numpy().astype(np.uint8)[0]
    non_zero_similarity = cv2.erode(non_zero_similarity, kernel=np.ones((3, 3), np.uint8), iterations=2)
    non_zero_similarity = torch.from_numpy(non_zero_similarity).to(similarity_maps.device).unsqueeze(0) / 255.
    similarity_maps = non_zero_similarity.unsqueeze(-1) * similarity_maps

    return images, normal_maps, similarity_maps, depth_maps, fragments


@torch.no_grad()
def check_visible_faces(mesh, fragments):
    pix_to_face = fragments.pix_to_face

    # Indices of unique visible faces
    visible_map = pix_to_face.unique()  # (num_visible_faces)

    return visible_map
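Side note (not part of the committed file): get_relative_depth_map reverses the valid z-buffer values and rescales them to [50, 255], filling empty pixels with a small pad value so the background is not pure black. A minimal standalone sketch with a toy depth buffer:

import torch

absolute_depth = torch.tensor([[1.0, 2.0], [3.0, -1.0]])   # toy zbuf; -1 = no hit
no_depth, pad_value = -1, 10
target_min, target_max = 50, 255

valid = absolute_depth != no_depth
d = absolute_depth[valid]
d = (d.max() - d) / (d.max() - d.min())                    # reverse and normalise to [0, 1]
d = d * (target_max - target_min) + target_min             # map to [50, 255]

relative_depth = absolute_depth.clone()
relative_depth[valid] = d
relative_depth[~valid] = pad_value
print(relative_depth)                                      # [[255.0, 152.5], [50.0, 10.0]]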