Spaces commit e62f618
Parent(s): 565c7be
Commit message: delete

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.
- README.md +2 -3
- app.py +0 -284
- assets/logo.png +0 -0
- assets/overview_3.png +0 -0
- assets/radar.png +0 -0
- assets/runtime.png +0 -0
- assets/teaser.png +0 -3
- demos/example_000.png +0 -0
- demos/example_001.png +0 -0
- demos/example_002.png +0 -0
- demos/example_003.png +0 -3
- demos/example_list.txt +0 -2
- infer/__init__.py +0 -28
- infer/gif_render.py +0 -55
- infer/image_to_views.py +0 -81
- infer/rembg.py +0 -26
- infer/text_to_image.py +0 -80
- infer/utils.py +0 -77
- infer/views_to_mesh.py +0 -94
- mvd/__init__.py +0 -0
- mvd/hunyuan3d_mvd_lite_pipeline.py +0 -493
- mvd/hunyuan3d_mvd_std_pipeline.py +0 -471
- mvd/utils.py +0 -85
- requirements.txt +0 -22
- scripts/image_to_3d.sh +0 -8
- scripts/image_to_3d_demo.sh +0 -8
- scripts/image_to_3d_fast.sh +0 -6
- scripts/image_to_3d_fast_demo.sh +0 -6
- scripts/text_to_3d.sh +0 -7
- scripts/text_to_3d_demo.sh +0 -7
- scripts/text_to_3d_fast.sh +0 -6
- scripts/text_to_3d_fast_demo.sh +0 -6
- svrm/.DS_Store +0 -0
- svrm/configs/2024-10-24T22-36-18-project.yaml +0 -32
- svrm/configs/svrm.yaml +0 -32
- svrm/ldm/.DS_Store +0 -0
- svrm/ldm/models/svrm.py +0 -263
- svrm/ldm/modules/attention.py +0 -457
- svrm/ldm/modules/encoders/__init__.py +0 -0
- svrm/ldm/modules/encoders/dinov2/__init__.py +0 -0
- svrm/ldm/modules/encoders/dinov2/hub/__init__.py +0 -0
- svrm/ldm/modules/encoders/dinov2/hub/backbones.py +0 -156
- svrm/ldm/modules/encoders/dinov2/hub/utils.py +0 -39
- svrm/ldm/modules/encoders/dinov2/layers/__init__.py +0 -11
- svrm/ldm/modules/encoders/dinov2/layers/attention.py +0 -89
- svrm/ldm/modules/encoders/dinov2/layers/block.py +0 -269
- svrm/ldm/modules/encoders/dinov2/layers/dino_head.py +0 -58
- svrm/ldm/modules/encoders/dinov2/layers/drop_path.py +0 -34
- svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py +0 -27
- svrm/ldm/modules/encoders/dinov2/layers/mlp.py +0 -40
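Because the view above is truncated at 50 files, the complete changeset can also be inspected locally. A minimal sketch, assuming a local clone of this Space repository (only standard git commands; the clone step itself is not shown):

# Summary of every file touched by this commit
git show --stat e62f618
# Full raw diff between the parent commit and this one
git diff 565c7be e62f618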
README.md
CHANGED
@@ -1,11 +1,10 @@
 ---
-title:
+title: Image Procesing
 emoji: 😻
 colorFrom: purple
 colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 5.3.0
 app_file: app.py
 pinned: false
-short_description: Text-to-3D and Image-to-3D Generation
 ---
app.py
DELETED
@@ -1,284 +0,0 @@
-import os
-import warnings
-from huggingface_hub import hf_hub_download
-import gradio as gr
-from glob import glob
-import shutil
-import torch
-import numpy as np
-from PIL import Image
-from einops import rearrange
-import argparse
-
-# Suppress warnings
-warnings.simplefilter('ignore', category=UserWarning)
-warnings.simplefilter('ignore', category=FutureWarning)
-warnings.simplefilter('ignore', category=DeprecationWarning)
-
-def download_models():
-    # Create weights directory if it doesn't exist
-    os.makedirs("weights", exist_ok=True)
-    os.makedirs("weights/hunyuanDiT", exist_ok=True)
-
-    # Download Hunyuan3D-1 model
-    try:
-        hf_hub_download(
-            repo_id="tencent/Hunyuan3D-1",
-            local_dir="./weights",
-            resume_download=True
-        )
-        print("Successfully downloaded Hunyuan3D-1 model")
-    except Exception as e:
-        print(f"Error downloading Hunyuan3D-1: {e}")
-
-    # Download HunyuanDiT model
-    try:
-        hf_hub_download(
-            repo_id="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled",
-            local_dir="./weights/hunyuanDiT",
-            resume_download=True
-        )
-        print("Successfully downloaded HunyuanDiT model")
-    except Exception as e:
-        print(f"Error downloading HunyuanDiT: {e}")
-
-# Download models before starting the app
-download_models()
-
-# Parse arguments
-parser = argparse.ArgumentParser()
-parser.add_argument("--use_lite", default=False, action="store_true")
-parser.add_argument("--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str)
-parser.add_argument("--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str)
-parser.add_argument("--text2image_path", default="weights/hunyuanDiT", type=str)
-parser.add_argument("--save_memory", default=False, action="store_true")
-parser.add_argument("--device", default="cuda:0", type=str)
-args = parser.parse_args()
-
-# Constants
-CONST_PORT = 8080
-CONST_MAX_QUEUE = 1
-CONST_SERVER = '0.0.0.0'
-
-CONST_HEADER = '''
-<h2><b>Official 🤗 Gradio Demo</b></h2>
-<h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>
-<b>Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation</b></a></h2>
-'''
-
-# Helper functions
-def get_example_img_list():
-    print('Loading example img list ...')
-    return sorted(glob('./demos/example_*.png'))
-
-def get_example_txt_list():
-    print('Loading example txt list ...')
-    txt_list = []
-    for line in open('./demos/example_list.txt'):
-        txt_list.append(line.strip())
-    return txt_list
-
-example_is = get_example_img_list()
-example_ts = get_example_txt_list()
-
-# Import required workers
-from infer import seed_everything, save_gif
-from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
-
-# Initialize workers
-worker_xbg = Removebg()
-print(f"loading {args.text2image_path}")
-worker_t2i = Text2Image(
-    pretrain=args.text2image_path,
-    device=args.device,
-    save_memory=args.save_memory
-)
-worker_i2v = Image2Views(
-    use_lite=args.use_lite,
-    device=args.device
-)
-worker_v23 = Views2Mesh(
-    args.mv23d_cfg_path,
-    args.mv23d_ckt_path,
-    use_lite=args.use_lite,
-    device=args.device
-)
-worker_gif = GifRenderer(args.device)
-
-# Pipeline stages
-def stage_0_t2i(text, image, seed, step):
-    os.makedirs('./outputs/app_output', exist_ok=True)
-    exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith("."))
-    cur_id = min(set(range(30)) - exists) if len(exists) < 30 else 0
-
-    if os.path.exists(f"./outputs/app_output/{(cur_id + 1) % 30}"):
-        shutil.rmtree(f"./outputs/app_output/{(cur_id + 1) % 30}")
-    save_folder = f'./outputs/app_output/{cur_id}'
-    os.makedirs(save_folder, exist_ok=True)
-
-    dst = save_folder + '/img.png'
-
-    if not text:
-        if image is None:
-            return dst, save_folder
-        image.save(dst)
-        return dst, save_folder
-
-    image = worker_t2i(text, seed, step)
-    image.save(dst)
-    dst = worker_xbg(image, save_folder)
-    return dst, save_folder
-
-def stage_1_xbg(image, save_folder):
-    if isinstance(image, str):
-        image = Image.open(image)
-    dst = save_folder + '/img_nobg.png'
-    rgba = worker_xbg(image)
-    rgba.save(dst)
-    return dst
-
-def stage_2_i2v(image, seed, step, save_folder):
-    if isinstance(image, str):
-        image = Image.open(image)
-    gif_dst = save_folder + '/views.gif'
-    res_img, pils = worker_i2v(image, seed, step)
-    save_gif(pils, gif_dst)
-    views_img, cond_img = res_img[0], res_img[1]
-    img_array = np.asarray(views_img, dtype=np.uint8)
-    show_img = rearrange(img_array, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
-    show_img = show_img[worker_i2v.order, ...]
-    show_img = rearrange(show_img, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
-    show_img = Image.fromarray(show_img)
-    return views_img, cond_img, show_img
-
-def stage_3_v23(views_pil, cond_pil, seed, save_folder, target_face_count=30000,
-                do_texture_mapping=True, do_render=True):
-    do_texture_mapping = do_texture_mapping or do_render
-    obj_dst = save_folder + '/mesh_with_colors.obj'
-    glb_dst = save_folder + '/mesh.glb'
-    worker_v23(
-        views_pil,
-        cond_pil,
-        seed=seed,
-        save_folder=save_folder,
-        target_face_count=target_face_count,
-        do_texture_mapping=do_texture_mapping
-    )
-    return obj_dst, glb_dst
-
-def stage_4_gif(obj_dst, save_folder, do_render_gif=True):
-    if not do_render_gif:
-        return None
-    gif_dst = save_folder + '/output.gif'
-    worker_gif(
-        save_folder + '/mesh.obj',
-        gif_dst_path=gif_dst
-    )
-    return gif_dst
-
-# Gradio Interface
-with gr.Blocks() as demo:
-    gr.Markdown(CONST_HEADER)
-
-    with gr.Row(variant="panel"):
-        with gr.Column(scale=2):
-            with gr.Tab("Text to 3D"):
-                with gr.Column():
-                    text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。',
-                                       lines=1, max_lines=10, label='Input text')
-                    with gr.Row():
-                        textgen_seed = gr.Number(value=0, label="T2I seed", precision=0)
-                        textgen_step = gr.Number(value=25, label="T2I step", precision=0)
-                        textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
-                        textgen_STEP = gr.Number(value=50, label="Gen step", precision=0)
-                        textgen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
-
-                    with gr.Row():
-                        textgen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False)
-                        textgen_do_render_gif = gr.Checkbox(label="Render gif", value=False)
-                        textgen_submit = gr.Button("Generate", variant="primary")
-
-                    gr.Examples(examples=example_ts, inputs=[text], label="Txt examples")
-
-            with gr.Tab("Image to 3D"):
-                with gr.Column():
-                    input_image = gr.Image(label="Input image", width=256, height=256,
-                                           type="pil", image_mode="RGBA", sources="upload")
-                    with gr.Row():
-                        imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
-                        imggen_STEP = gr.Number(value=50, label="Gen step", precision=0)
-                        imggen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
-
-                    with gr.Row():
-                        imggen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False)
-                        imggen_do_render_gif = gr.Checkbox(label="Render gif", value=False)
-                        imggen_submit = gr.Button("Generate", variant="primary")
-
-                    gr.Examples(examples=example_is, inputs=[input_image], label="Img examples")
-
-        with gr.Column(scale=3):
-            with gr.Tab("rembg image"):
-                rem_bg_image = gr.Image(label="No background image", width=256, height=256,
-                                        type="pil", image_mode="RGBA")
-
-            with gr.Tab("Multi views"):
-                result_image = gr.Image(label="Multi views", type="pil")
-            with gr.Tab("Obj"):
-                result_3dobj = gr.Model3D(label="Output obj")
-            with gr.Tab("Glb"):
-                result_3dglb = gr.Model3D(label="Output glb")
-            with gr.Tab("GIF"):
-                result_gif = gr.Image(label="Rendered GIF")
-
-    # States
-    none = gr.State(None)
-    save_folder = gr.State()
-    cond_image = gr.State()
-    views_image = gr.State()
-    text_image = gr.State()
-
-    # Event handlers
-    textgen_submit.click(
-        fn=stage_0_t2i,
-        inputs=[text, none, textgen_seed, textgen_step],
-        outputs=[rem_bg_image, save_folder],
-    ).success(
-        fn=stage_2_i2v,
-        inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder],
-        outputs=[views_image, cond_image, result_image],
-    ).success(
-        fn=stage_3_v23,
-        inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces,
-                textgen_do_texture_mapping, textgen_do_render_gif],
-        outputs=[result_3dobj, result_3dglb],
-    ).success(
-        fn=stage_4_gif,
-        inputs=[result_3dglb, save_folder, textgen_do_render_gif],
-        outputs=[result_gif],
-    )
-
-    imggen_submit.click(
-        fn=stage_0_t2i,
-        inputs=[none, input_image, textgen_seed, textgen_step],
-        outputs=[text_image, save_folder],
-    ).success(
-        fn=stage_1_xbg,
-        inputs=[text_image, save_folder],
-        outputs=[rem_bg_image],
-    ).success(
-        fn=stage_2_i2v,
-        inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder],
-        outputs=[views_image, cond_image, result_image],
-    ).success(
-        fn=stage_3_v23,
-        inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces,
-                imggen_do_texture_mapping, imggen_do_render_gif],
-        outputs=[result_3dobj, result_3dglb],
-    ).success(
-        fn=stage_4_gif,
-        inputs=[result_3dglb, save_folder, imggen_do_render_gif],
-        outputs=[result_gif],
-    )
-
-demo.queue(max_size=CONST_MAX_QUEUE)
-demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
assets/logo.png
DELETED
Binary file (314 kB)

assets/overview_3.png
DELETED
Binary file (271 kB)

assets/radar.png
DELETED
Binary file (122 kB)

assets/runtime.png
DELETED
Binary file (38.4 kB)

assets/teaser.png
DELETED
Binary file (stored in Git LFS)

demos/example_000.png
DELETED
Binary file (659 kB)

demos/example_001.png
DELETED
Binary file (817 kB)

demos/example_002.png
DELETED
Binary file (339 kB)

demos/example_003.png
DELETED
Binary file (stored in Git LFS)
demos/example_list.txt
DELETED
@@ -1,2 +0,0 @@
-a pot of green plants grows in a red flower pot.
-a lovely rabbit eating carrots
infer/__init__.py
DELETED
@@ -1,28 +0,0 @@
-# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
-# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
-
-# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.
-# The below software and/or models in this distribution may have been
-# modified by THL A29 Limited ("Tencent Modifications").
-# All Tencent Modifications are Copyright (C) THL A29 Limited.
-
-# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
-# except for the third-party components listed below.
-# Hunyuan 3D does not impose any additional limitations beyond what is outlined
-# in the repsective licenses of these third-party components.
-# Users must comply with all terms and conditions of original licenses of these third-party
-# components and must ensure that the usage of the third party components adheres to
-# all relevant laws and regulations.
-
-# For avoidance of doubts, Hunyuan 3D means the large language models and
-# their software and algorithms, including trained model weights, parameters (including
-# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
-# fine-tuning enabling code and other elements of the foregoing made publicly available
-# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
-
-from .utils import seed_everything, timing_decorator, auto_amp_inference
-from .rembg import Removebg
-from .text_to_image import Text2Image
-from .image_to_views import Image2Views, save_gif
-from .views_to_mesh import Views2Mesh
-from .gif_render import GifRenderer
infer/gif_render.py
DELETED
@@ -1,55 +0,0 @@
-# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
-# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
-
-# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.
-# The below software and/or models in this distribution may have been
-# modified by THL A29 Limited ("Tencent Modifications").
-# All Tencent Modifications are Copyright (C) THL A29 Limited.
-
-# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
-# except for the third-party components listed below.
-# Hunyuan 3D does not impose any additional limitations beyond what is outlined
-# in the repsective licenses of these third-party components.
-# Users must comply with all terms and conditions of original licenses of these third-party
-# components and must ensure that the usage of the third party components adheres to
-# all relevant laws and regulations.
-
-# For avoidance of doubts, Hunyuan 3D means the large language models and
-# their software and algorithms, including trained model weights, parameters (including
-# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
-# fine-tuning enabling code and other elements of the foregoing made publicly available
-# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
-
-from svrm.ldm.vis_util import render
-from .utils import seed_everything, timing_decorator
-
-class GifRenderer():
-    '''
-    render frame(s) of mesh using pytorch3d
-    '''
-    def __init__(self, device="cuda:0"):
-        self.device = device
-
-    @timing_decorator("gif render")
-    def __call__(
-        self,
-        obj_filename,
-        elev=0,
-        azim=0,
-        resolution=512,
-        gif_dst_path='',
-        n_views=120,
-        fps=30,
-        rgb=True
-    ):
-        render(
-            obj_filename,
-            elev=elev,
-            azim=azim,
-            resolution=resolution,
-            gif_dst_path=gif_dst_path,
-            n_views=n_views,
-            fps=fps,
-            device=self.device,
-            rgb=rgb
-        )
infer/image_to_views.py
DELETED
@@ -1,81 +0,0 @@
-# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
-# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
-
-# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.
-# The below software and/or models in this distribution may have been
-# modified by THL A29 Limited ("Tencent Modifications").
-# All Tencent Modifications are Copyright (C) THL A29 Limited.
-
-# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
-# except for the third-party components listed below.
-# Hunyuan 3D does not impose any additional limitations beyond what is outlined
-# in the repsective licenses of these third-party components.
-# Users must comply with all terms and conditions of original licenses of these third-party
-# components and must ensure that the usage of the third party components adheres to
-# all relevant laws and regulations.
-
-# For avoidance of doubts, Hunyuan 3D means the large language models and
-# their software and algorithms, including trained model weights, parameters (including
-# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
-# fine-tuning enabling code and other elements of the foregoing made publicly available
-# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
-
-import os
-import time
-import torch
-import random
-import numpy as np
-from PIL import Image
-from einops import rearrange
-from PIL import Image, ImageSequence
-
-from .utils import seed_everything, timing_decorator, auto_amp_inference
-from .utils import get_parameter_number, set_parameter_grad_false
-from mvd.hunyuan3d_mvd_std_pipeline import HunYuan3D_MVD_Std_Pipeline
-from mvd.hunyuan3d_mvd_lite_pipeline import Hunyuan3d_MVD_Lite_Pipeline
-
-
-def save_gif(pils, save_path, df=False):
-    # save a list of PIL.Image to gif
-    spf = 4000 / len(pils)
-    os.makedirs(os.path.dirname(save_path), exist_ok=True)
-    pils[0].save(save_path, format="GIF", save_all=True, append_images=pils[1:], duration=spf, loop=0)
-    return save_path
-
-
-class Image2Views():
-    def __init__(self, device="cuda:0", use_lite=False):
-        self.device = device
-        if use_lite:
-            self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
-                "./weights/mvd_lite",
-                torch_dtype = torch.float16,
-                use_safetensors = True,
-            )
-        else:
-            self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
-                "./weights/mvd_std",
-                torch_dtype = torch.float16,
-                use_safetensors = True,
-            )
-        self.pipe = self.pipe.to(device)
-        self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
-        set_parameter_grad_false(self.pipe.unet)
-        print('image2views unet model', get_parameter_number(self.pipe.unet))
-
-    @torch.no_grad()
-    @timing_decorator("image to views")
-    @auto_amp_inference
-    def __call__(self, pil_img, seed=0, steps=50, guidance_scale=2.0, guidance_curve=lambda t:2.0):
-        seed_everything(seed)
-        generator = torch.Generator(device=self.device)
-        res_img = self.pipe(pil_img,
-                            num_inference_steps=steps,
-                            guidance_scale=guidance_scale,
-                            guidance_curve=guidance_curve,
-                            generat=generator).images
-        show_image = rearrange(np.asarray(res_img[0], dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2)
-        pils = [res_img[1]]+[Image.fromarray(show_image[idx]) for idx in self.order]
-        torch.cuda.empty_cache()
-        return res_img, pils
-
infer/rembg.py
DELETED
@@ -1,26 +0,0 @@
-from rembg import remove, new_session
-from .utils import timing_decorator
-
-class Removebg():
-    def __init__(self, name="u2net"):
-        '''
-        name: rembg
-        '''
-        self.session = new_session(name)
-
-    @timing_decorator("remove background")
-    def __call__(self, rgb_img, force=False):
-        '''
-        inputs:
-            rgb_img: PIL.Image, with RGB mode expected
-            force: bool, input is RGBA mode
-        return:
-            rgba_img: PIL.Image with RGBA mode
-        '''
-        if rgb_img.mode == "RGBA":
-            if force:
-                rgb_img = rgb_img.convert("RGB")
-            else:
-                return rgb_img
-        rgba_img = remove(rgb_img, session=self.session)
-        return rgba_img
infer/text_to_image.py
DELETED
@@ -1,80 +0,0 @@
-# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
-# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
-
-# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.
-# The below software and/or models in this distribution may have been
-# modified by THL A29 Limited ("Tencent Modifications").
-# All Tencent Modifications are Copyright (C) THL A29 Limited.
-
-# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
-# except for the third-party components listed below.
-# Hunyuan 3D does not impose any additional limitations beyond what is outlined
-# in the repsective licenses of these third-party components.
-# Users must comply with all terms and conditions of original licenses of these third-party
-# components and must ensure that the usage of the third party components adheres to
-# all relevant laws and regulations.
-
-# For avoidance of doubts, Hunyuan 3D means the large language models and
-# their software and algorithms, including trained model weights, parameters (including
-# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
-# fine-tuning enabling code and other elements of the foregoing made publicly available
-# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
-
-import torch
-from .utils import seed_everything, timing_decorator, auto_amp_inference
-from .utils import get_parameter_number, set_parameter_grad_false
-from diffusers import HunyuanDiTPipeline, AutoPipelineForText2Image
-
-class Text2Image():
-    def __init__(self, pretrain="weights/hunyuanDiT", device="cuda:0", save_memory=False):
-        '''
-        save_memory: if GPU memory is low, can set it
-        '''
-        self.save_memory = save_memory
-        self.device = device
-        self.pipe = AutoPipelineForText2Image.from_pretrained(
-            pretrain,
-            torch_dtype = torch.float16,
-            enable_pag = True,
-            pag_applied_layers = ["blocks.(16|17|18|19)"]
-        )
-        set_parameter_grad_false(self.pipe.transformer)
-        print('text2image transformer model', get_parameter_number(self.pipe.transformer))
-        if not save_memory:
-            self.pipe = self.pipe.to(device)
-        self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \
-                       "画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \
-                       "毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
-
-    @torch.no_grad()
-    @timing_decorator('text to image')
-    @auto_amp_inference
-    def __call__(self, *args, **kwargs):
-        if self.save_memory:
-            self.pipe = self.pipe.to(self.device)
-            torch.cuda.empty_cache()
-            res = self.call(*args, **kwargs)
-            self.pipe = self.pipe.to("cpu")
-        else:
-            res = self.call(*args, **kwargs)
-        torch.cuda.empty_cache()
-        return res
-
-    def call(self, prompt, seed=0, steps=25):
-        '''
-        inputs:
-            prompr: str
-            seed: int
-            steps: int
-        return:
-            rgb: PIL.Image
-        '''
-        prompt = prompt + ",白色背景,3D风格,最佳质量"
-        seed_everything(seed)
-        generator = torch.Generator(device=self.device)
-        if seed is not None: generator = generator.manual_seed(int(seed))
-        rgb = self.pipe(prompt=prompt, negative_prompt=self.neg_txt, num_inference_steps=steps,
-                        pag_scale=1.3, width=1024, height=1024, generator=generator, return_dict=False)[0][0]
-        torch.cuda.empty_cache()
-        return rgb
-
infer/utils.py
DELETED
@@ -1,77 +0,0 @@
-# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
-# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
-
-# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.
-# The below software and/or models in this distribution may have been
-# modified by THL A29 Limited ("Tencent Modifications").
-# All Tencent Modifications are Copyright (C) THL A29 Limited.
-
-# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
-# except for the third-party components listed below.
-# Hunyuan 3D does not impose any additional limitations beyond what is outlined
-# in the repsective licenses of these third-party components.
-# Users must comply with all terms and conditions of original licenses of these third-party
-# components and must ensure that the usage of the third party components adheres to
-# all relevant laws and regulations.
-
-# For avoidance of doubts, Hunyuan 3D means the large language models and
-# their software and algorithms, including trained model weights, parameters (including
-# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
-# fine-tuning enabling code and other elements of the foregoing made publicly available
-# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
-
-import os
-import time
-import random
-import numpy as np
-import torch
-from torch.cuda.amp import autocast, GradScaler
-from functools import wraps
-
-def seed_everything(seed):
-    '''
-    seed everthing
-    '''
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    os.environ["PL_GLOBAL_SEED"] = str(seed)
-
-def timing_decorator(category: str):
-    '''
-    timing_decorator: record time
-    '''
-    def decorator(func):
-        func.call_count = 0
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            start_time = time.time()
-            result = func(*args, **kwargs)
-            end_time = time.time()
-            elapsed_time = end_time - start_time
-            func.call_count += 1
-            print(f"[HunYuan3D]-[{category}], cost time: {elapsed_time:.4f}s") # huiwen
-            return result
-        return wrapper
-    return decorator
-
-def auto_amp_inference(func):
-    '''
-    with torch.cuda.amp.autocast()"
-        xxx
-    '''
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        with autocast():
-            output = func(*args, **kwargs)
-            return output
-    return wrapper
-
-def get_parameter_number(model):
-    total_num = sum(p.numel() for p in model.parameters())
-    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    return {'Total': total_num, 'Trainable': trainable_num}
-
-def set_parameter_grad_false(model):
-    for p in model.parameters():
-        p.requires_grad = False
infer/views_to_mesh.py
DELETED
@@ -1,94 +0,0 @@
-# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
-# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
-
-# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.
-# The below software and/or models in this distribution may have been
-# modified by THL A29 Limited ("Tencent Modifications").
-# All Tencent Modifications are Copyright (C) THL A29 Limited.
-
-# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
-# except for the third-party components listed below.
-# Hunyuan 3D does not impose any additional limitations beyond what is outlined
-# in the repsective licenses of these third-party components.
-# Users must comply with all terms and conditions of original licenses of these third-party
-# components and must ensure that the usage of the third party components adheres to
-# all relevant laws and regulations.
-
-# For avoidance of doubts, Hunyuan 3D means the large language models and
-# their software and algorithms, including trained model weights, parameters (including
-# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
-# fine-tuning enabling code and other elements of the foregoing made publicly available
-# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
-
-import os
-import time
-import torch
-import random
-import numpy as np
-from PIL import Image
-from einops import rearrange
-from PIL import Image, ImageSequence
-
-from .utils import seed_everything, timing_decorator, auto_amp_inference
-from .utils import get_parameter_number, set_parameter_grad_false
-from svrm.predictor import MV23DPredictor
-
-
-class Views2Mesh():
-    def __init__(self, mv23d_cfg_path, mv23d_ckt_path, device="cuda:0", use_lite=False):
-        '''
-        mv23d_cfg_path: config yaml file
-        mv23d_ckt_path: path to ckpt
-        use_lite:
-        '''
-        self.mv23d_predictor = MV23DPredictor(mv23d_ckt_path, mv23d_cfg_path, device=device)
-        self.mv23d_predictor.model.eval()
-        self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
-        set_parameter_grad_false(self.mv23d_predictor.model)
-        print('view2mesh model', get_parameter_number(self.mv23d_predictor.model))
-
-    @torch.no_grad()
-    @timing_decorator("views to mesh")
-    @auto_amp_inference
-    def __call__(
-        self,
-        views_pil=None,
-        cond_pil=None,
-        gif_pil=None,
-        seed=0,
-        target_face_count = 10000,
-        do_texture_mapping = True,
-        save_folder='./outputs/test'
-    ):
-        '''
-        can set views_pil, cond_pil simutaously or set gif_pil only
-        seed: int
-        target_face_count: int
-        save_folder: path to save mesh files
-        '''
-        save_dir = save_folder
-        os.makedirs(save_dir, exist_ok=True)
-
-        if views_pil is not None and cond_pil is not None:
-            show_image = rearrange(np.asarray(views_pil, dtype=np.uint8),
-                                   '(n h) (m w) c -> (n m) h w c', n=3, m=2)
-            views = [Image.fromarray(show_image[idx]) for idx in self.order]
-            image_list = [cond_pil]+ views
-            image_list = [img.convert('RGB') for img in image_list]
-        elif gif_pil is not None:
-            image_list = [img.convert('RGB') for img in ImageSequence.Iterator(gif_pil)]
-
-        image_input = image_list[0]
-        image_list = image_list[1:] + image_list[:1]
-
-        seed_everything(seed)
-        self.mv23d_predictor.predict(
-            image_list,
-            save_dir = save_dir,
-            image_input = image_input,
-            target_face_count = target_face_count,
-            do_texture_mapping = do_texture_mapping
-        )
-        torch.cuda.empty_cache()
-        return save_dir
-
mvd/__init__.py
DELETED
File without changes
mvd/hunyuan3d_mvd_lite_pipeline.py
DELETED
@@ -1,493 +0,0 @@
-# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
-# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
-
-# Copyright (C) 2024 THL A29 Limited, a Tencent company.  All rights reserved.
-# The below software and/or models in this distribution may have been
-# modified by THL A29 Limited ("Tencent Modifications").
-# All Tencent Modifications are Copyright (C) THL A29 Limited.
-
-# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
-# except for the third-party components listed below.
-# Hunyuan 3D does not impose any additional limitations beyond what is outlined
-# in the repsective licenses of these third-party components.
-# Users must comply with all terms and conditions of original licenses of these third-party
-# components and must ensure that the usage of the third party components adheres to
-# all relevant laws and regulations.
-
-# For avoidance of doubts, Hunyuan 3D means the large language models and
-# their software and algorithms, including trained model weights, parameters (including
-# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
-# fine-tuning enabling code and other elements of the foregoing made publicly available
-# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
-
-import math
-import numpy
-import torch
-import inspect
-import warnings
-from PIL import Image
-from einops import rearrange
-import torch.nn.functional as F
-from diffusers.utils.torch_utils import randn_tensor
-from diffusers.configuration_utils import FrozenDict
-from diffusers.image_processor import VaeImageProcessor
-from typing import Any, Callable, Dict, List, Optional, Union
-from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers import DDPMScheduler, EulerAncestralDiscreteScheduler, ImagePipelineOutput
-from diffusers.loaders import (
-    FromSingleFileMixin,
-    LoraLoaderMixin,
-    TextualInversionLoaderMixin
-)
-from transformers import (
-    CLIPImageProcessor,
-    CLIPTextModel,
-    CLIPTokenizer,
-    CLIPVisionModelWithProjection
-)
-from diffusers.models.attention_processor import (
-    Attention,
-    AttnProcessor,
-    XFormersAttnProcessor,
-    AttnProcessor2_0
-)
-
-from .utils import to_rgb_image, white_out_background, recenter_img
-
-
-EXAMPLE_DOC_STRING = """
-    Examples:
-        ```py
-        >>> import torch
-        >>> from here import Hunyuan3d_MVD_Qing_Pipeline
-
-        >>> pipe = Hunyuan3d_MVD_Qing_Pipeline.from_pretrained(
-        ...     "Tencent-Hunyuan-3D/MVD-Qing", torch_dtype=torch.float16
-        ... )
-        >>> pipe.to("cuda")
-
-        >>> img = Image.open("demo.png")
-        >>> res_img = pipe(img).images[0]
-"""
-
-def unscale_latents(latents): return latents / 0.75 + 0.22
-def unscale_image  (image  ): return image   / 0.50 * 0.80
-
-
-def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
-    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
-    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
-    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
-    return noise_cfg
-
-
-
-class ReferenceOnlyAttnProc(torch.nn.Module):
-    # reference attention
-    def __init__(self, chained_proc, enabled=False, name=None):
-        super().__init__()
-        self.enabled = enabled
-        self.chained_proc = chained_proc
-        self.name = name
-
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
-        if encoder_hidden_states is None: encoder_hidden_states = hidden_states
-        if self.enabled:
-            if mode == 'w':
-                ref_dict[self.name] = encoder_hidden_states
-            elif mode == 'r':
-                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
-        res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
-        return res
-
-
-# class RowWiseAttnProcessor2_0:
-#     def __call__(self, attn,
-#                  hidden_states,
-#                  encoder_hidden_states=None,
-#                  attention_mask=None,
-#                  temb=None,
-#                  num_views=6,
-#                  *args,
-#                  **kwargs):
-#         residual = hidden_states
-#         if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb)
-
-#         input_ndim = hidden_states.ndim
-#         if input_ndim == 4:
-#             batch_size, channel, height, width = hidden_states.shape
-#             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-#         if encoder_hidden_states is None:
-#             batch_size, sequence_length, _ = hidden_states.shape
-#         else:
-#             batch_size, sequence_length, _ = encoder_hidden_states.shape
-
-#         if attention_mask is not None:
-#             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-#             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-#         if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-#         query = attn.to_q(hidden_states)
-#         if encoder_hidden_states is None: encoder_hidden_states = hidden_states
-#         elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-#         # encoder_hidden_states [B, 6hw+hw, C] if ref att
-#         key = attn.to_k(encoder_hidden_states)    # [B, Vhw+hw, C]
-#         value = attn.to_v(encoder_hidden_states)  # [B, Vhw+hw, C]
-
-#         mv_flag = hidden_states.shape[1] < encoder_hidden_states.shape[1] and encoder_hidden_states.shape[1] != 77
-#         if mv_flag:
-#             target_size = int(math.sqrt(hidden_states.shape[1] // num_views))
-#             assert target_size ** 2 * num_views == hidden_states.shape[1]
-
-#             gen_key = key[:, :num_views*target_size*target_size, :]
-#             ref_key = key[:, num_views*target_size*target_size:, :]
-#             gen_value = value[:, :num_views*target_size*target_size, :]
-#             ref_value = value[:, num_views*target_size*target_size:, :]
-
-#             # rowwise attention
-#             query, gen_key, gen_value = \
-#                 rearrange( query, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
-#                           v1=num_views//2, v2=2, h=target_size, w=target_size), \
-#                 rearrange( gen_key, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
-#                           v1=num_views//2, v2=2, h=target_size, w=target_size), \
-#                 rearrange(gen_value, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
-#                           v1=num_views//2, v2=2, h=target_size, w=target_size)
-
-#             inner_dim = key.shape[-1]
-#             ref_size = int(math.sqrt(ref_key.shape[1]))
-#             ref_key_expanded = ref_key.view(batch_size, 1, ref_size * ref_size, inner_dim)
-#             ref_key_expanded = ref_key_expanded.expand(-1, target_size, -1, -1).contiguous()
-#             ref_key_expanded = ref_key_expanded.view(batch_size * target_size, ref_size * ref_size, inner_dim)
-#             key = torch.cat([ gen_key, ref_key_expanded], dim=1)
-
-#             ref_value_expanded = ref_value.view(batch_size, 1, ref_size * ref_size, inner_dim)
-#             ref_value_expanded = ref_value_expanded.expand(-1, target_size, -1, -1).contiguous()
-#             ref_value_expanded = ref_value_expanded.view(batch_size * target_size, ref_size * ref_size, inner_dim)
-#             value = torch.cat([gen_value, ref_value_expanded], dim=1)
-#             h = target_size
-#         else:
-#             target_size = int(math.sqrt(hidden_states.shape[1]))
-#             h = 1
-#             num_views = 1
-
-#         inner_dim = key.shape[-1]
-#         head_dim = inner_dim // attn.heads
-
-#         query = query.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
-#         key = key.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
-#         value = value.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
-
-#         hidden_states = F.scaled_dot_product_attention(query, key, value,
-#                                                        attn_mask=attention_mask,
-#                                                        dropout_p=0.0,
-#                                                        is_causal=False)
-#         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size * h,
-#                                                               -1,
-#                                                               attn.heads * head_dim).to(query.dtype)
-#         hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))
-
-#         if mv_flag: hidden_states = rearrange(hidden_states, "(b h) (v1 v2 w) c -> b (v1 h v2 w) c",
-#                                               b=batch_size, v1=num_views//2,
-#                                               v2=2, h=target_size, w=target_size)
-
-#         if input_ndim == 4:
-#             hidden_states = hidden_states.transpose(-1, -2)
-#             hidden_states = hidden_states.reshape(batch_size,
-#                                                   channel,
-#                                                   target_size,
-#                                                   target_size)
-#         if attn.residual_connection: hidden_states = hidden_states + residual
-#         hidden_states = hidden_states / attn.rescale_output_factor
-#         return hidden_states
-
-
-class RefOnlyNoisedUNet(torch.nn.Module):
-    def __init__(self, unet, train_sched, val_sched):
-        super().__init__()
-        self.unet = unet
-        self.train_sched = train_sched
-        self.val_sched = val_sched
-
-        unet_lora_attn_procs = dict()
-        for name, _ in unet.attn_processors.items():
-            unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(AttnProcessor2_0(),
-                                                               enabled=name.endswith("attn1.processor"),
-                                                               name=name)
-        unet.set_attn_processor(unet_lora_attn_procs)
-
-    def __getattr__(self, name: str):
-        try:
-            return super().__getattr__(name)
-        except AttributeError:
-            return getattr(self.unet, name)
-
-    def forward(self, sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs, **kwargs):
-        cond_lat = cross_attention_kwargs['cond_lat']
-        noise = torch.randn_like(cond_lat)
-        if self.training:
-            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
-            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
-        else:
-            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
-            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
-
-        ref_dict = {}
-        self.unet(noisy_cond_lat,
-                  timestep,
-                  encoder_hidden_states,
-                  *args,
-                  cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
-                  **kwargs)
-        return self.unet(sample,
-                         timestep,
-                         encoder_hidden_states,
-                         *args,
-                         cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict),
-                         **kwargs)
-
-
-class Hunyuan3d_MVD_Lite_Pipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
-    def __init__(
-        self,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModel,
-        tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
-        scheduler: KarrasDiffusionSchedulers,
-        vision_encoder: CLIPVisionModelWithProjection,
-        feature_extractor_clip: CLIPImageProcessor,
-        feature_extractor_vae: CLIPImageProcessor,
-        ramping_coefficients: Optional[list] = None,
-        safety_checker=None,
-    ):
-        DiffusionPipeline.__init__(self)
-        self.register_modules(
-            vae=vae,
-            unet=unet,
-            tokenizer=tokenizer,
-            scheduler=scheduler,
-            text_encoder=text_encoder,
-            vision_encoder=vision_encoder,
-            feature_extractor_vae=feature_extractor_vae,
-            feature_extractor_clip=feature_extractor_clip)
-        '''
-        rewrite the stable diffusion pipeline
-        vae: vae
-        unet: unet
-        tokenizer: tokenizer
-        scheduler: scheduler
-        text_encoder: text_encoder
-        vision_encoder: vision_encoder
-        feature_extractor_vae: feature_extractor_vae
-        feature_extractor_clip: feature_extractor_clip
-        '''
-        self.register_to_config(ramping_coefficients=ramping_coefficients)
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-
-    def prepare_extra_step_kwargs(self, generator, eta):
-        extra_step_kwargs = {}
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        if accepts_eta: extra_step_kwargs["eta"] = eta
-
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        if accepts_generator: extra_step_kwargs["generator"] = generator
-        return extra_step_kwargs
-
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = latents * self.scheduler.init_noise_sigma
-        return latents
-
-    @torch.no_grad()
-    def _encode_prompt(
-        self,
-        prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        lora_scale: Optional[float] = None,
-    ):
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
-            self._lora_scale = lora_scale
-
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            if isinstance(self, TextualInversionLoaderMixin):
-                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = text_inputs.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)[0]
-
-        if self.text_encoder is not None:
-            prompt_embeds_dtype = self.text_encoder.dtype
-        elif self.unet is not None:
-            prompt_embeds_dtype = self.unet.dtype
-        else:
-            prompt_embeds_dtype = prompt_embeds.dtype
-
-        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens: List[str]
-            if negative_prompt is None: uncond_tokens = [""] * batch_size
-            elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError()
-            elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt): raise ValueError()
-            else: uncond_tokens = negative_prompt
-            if isinstance(self, TextualInversionLoaderMixin):
-                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
-
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(uncond_tokens,
-                                          padding="max_length",
-                                          max_length=max_length,
-                                          truncation=True,
-                                          return_tensors="pt")
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = uncond_input.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device), attention_mask=attention_mask)
-            negative_prompt_embeds = negative_prompt_embeds[0]
-
-        if do_classifier_free_guidance:
-            seq_len = negative_prompt_embeds.shape[1]
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        return prompt_embeds
-
-    @torch.no_grad()
-    def encode_condition_image(self, image: torch.Tensor): return self.vae.encode(image).latent_dist.sample()
-
-    @torch.no_grad()
-    def __call__(self, image=None,
-                 width=640,
-                 height=960,
-                 num_inference_steps=75,
-                 return_dict=True,
-                 generator=None,
-                 **kwargs):
-        batch_size = 1
-        num_images_per_prompt = 1
-        output_type = 'pil'
-        do_classifier_free_guidance = True
-        guidance_rescale = 0.
-        if isinstance(self.unet, UNet2DConditionModel):
-            self.unet = RefOnlyNoisedUNet(self.unet, None, self.scheduler).eval()
-
-        cond_image = recenter_img(image)
-        cond_image = to_rgb_image(image)
-        image = cond_image
-        image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
-        image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
-        image_1 = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
-        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
-
-        cond_lat = self.encode_condition_image(image_1)
-        negative_lat = self.encode_condition_image(torch.zeros_like(image_1))
-        cond_lat = torch.cat([negative_lat, cond_lat])
-        cross_attention_kwargs = dict(cond_lat=cond_lat)
-
-        global_embeds = self.vision_encoder(image_2, output_hidden_states=False).image_embeds.unsqueeze(-2)
-        encoder_hidden_states = self._encode_prompt('', self.device, num_images_per_prompt, False)
-        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
-        prompt_embeds = torch.cat([encoder_hidden_states, encoder_hidden_states + global_embeds * ramp])
-
-        device = self._execution_device
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
-        num_channels_latents = self.unet.config.in_channels
-        latents = self.prepare_latents(batch_size * num_images_per_prompt,
-                                       num_channels_latents,
-                                       height,
-                                       width,
-                                       prompt_embeds.dtype,
-                                       device,
-                                       generator,
-                                       None)
-        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-
-        # set adaptive cfg
-        # the image order is:
-        # [0, 60,
-        #  120, 180,
-        #  240, 300]
-        # the cfg is set as 3, 2.5, 2, 1.5
-
-        tmp_guidance_scale = torch.ones_like(latents)
-        tmp_guidance_scale[:, :, :40,    :40] = 3
-        tmp_guidance_scale[:, :, :40,    40:] = 2.5
-        tmp_guidance_scale[:, :, 40:80,  :40] = 2
-        tmp_guidance_scale[:, :, 40:80,  40:] = 1.5
-        tmp_guidance_scale[:, :, 80:120, :40] = 2
|
462 |
-
tmp_guidance_scale[:, :, 80:120, 40:] = 2.5
|
463 |
-
|
464 |
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
465 |
-
for i, t in enumerate(timesteps):
|
466 |
-
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
467 |
-
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
468 |
-
|
469 |
-
noise_pred = self.unet(latent_model_input, t,
|
470 |
-
encoder_hidden_states=prompt_embeds,
|
471 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
472 |
-
return_dict=False)[0]
|
473 |
-
|
474 |
-
adaptive_guidance_scale = (2 + 16 * (t / 1000) ** 5) / 3
|
475 |
-
if do_classifier_free_guidance:
|
476 |
-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
477 |
-
noise_pred = noise_pred_uncond + \
|
478 |
-
tmp_guidance_scale * adaptive_guidance_scale * \
|
479 |
-
(noise_pred_text - noise_pred_uncond)
|
480 |
-
|
481 |
-
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
482 |
-
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
483 |
-
|
484 |
-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
485 |
-
if i==len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order==0):
|
486 |
-
progress_bar.update()
|
487 |
-
|
488 |
-
latents = unscale_latents(latents)
|
489 |
-
image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
|
490 |
-
image = self.image_processor.postprocess(image, output_type='pil')[0]
|
491 |
-
image = [image, cond_image]
|
492 |
-
return ImagePipelineOutput(images=image) if return_dict else (image,)
|
493 |
-
mvd/hunyuan3d_mvd_std_pipeline.py
DELETED
@@ -1,471 +0,0 @@
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import inspect
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Tuple, Union

import os
import torch
import numpy as np
from PIL import Image

import diffusers
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention_processor import (
    Attention,
    AttnProcessor,
    XFormersAttnProcessor,
    AttnProcessor2_0
)
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    EulerAncestralDiscreteScheduler,
    UNet2DConditionModel,
    ImagePipelineOutput
)
import transformers
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    CLIPTextModelWithProjection
)

from .utils import to_rgb_image, white_out_background, recenter_img

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import Hunyuan3d_MVD_XL_Pipeline

        >>> pipe = Hunyuan3d_MVD_XL_Pipeline.from_pretrained(
        ...     "Tencent-Hunyuan-3D/MVD-XL", torch_dtype=torch.float16
        ... )
        >>> pipe.to("cuda")

        >>> img = Image.open("demo.png")
        >>> res_img = pipe(img).images[0]
        ```
"""


def scale_latents(latents): return (latents - 0.22) * 0.75
def unscale_latents(latents): return (latents / 0.75) + 0.22
def scale_image(image): return (image - 0.5) / 0.5
def scale_image_2(image): return (image * 0.5) / 0.8
def unscale_image(image): return (image * 0.5) + 0.5
def unscale_image_2(image): return (image * 0.8) / 0.5


class ReferenceOnlyAttnProc(torch.nn.Module):
    def __init__(self, chained_proc, enabled=False, name=None):
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
        encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        if self.enabled:
            if mode == 'w': ref_dict[self.name] = encoder_hidden_states
            elif mode == 'r': encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
            else: raise Exception(f"mode should not be {mode}")
        return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)


class RefOnlyNoisedUNet(torch.nn.Module):
    def __init__(self, unet, scheduler) -> None:
        super().__init__()
        self.unet = unet
        self.scheduler = scheduler

        unet_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            if torch.__version__ >= '2.0': default_attn_proc = AttnProcessor2_0()
            elif is_xformers_available(): default_attn_proc = XFormersAttnProcessor()
            else: default_attn_proc = AttnProcessor()
            unet_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
            )
        unet.set_attn_processor(unet_attn_procs)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.Tensor] = None,
        down_block_res_samples: Optional[Tuple[torch.Tensor]] = None,
        mid_block_res_sample: Optional[Tuple[torch.Tensor]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        return_dict: bool = True,
        **kwargs
    ):

        dtype = self.unet.dtype

        # cond_lat add same level noise
        cond_lat = cross_attention_kwargs['cond_lat']
        noise = torch.randn_like(cond_lat)

        noisy_cond_lat = self.scheduler.add_noise(cond_lat, noise, timestep.reshape(-1))
        noisy_cond_lat = self.scheduler.scale_model_input(noisy_cond_lat, timestep.reshape(-1))

        ref_dict = {}

        _ = self.unet(
            noisy_cond_lat,
            timestep,
            encoder_hidden_states = encoder_hidden_states,
            class_labels = class_labels,
            cross_attention_kwargs = dict(mode="w", ref_dict=ref_dict),
            added_cond_kwargs = added_cond_kwargs,
            return_dict = return_dict,
            **kwargs
        )

        res = self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs = dict(mode="r", ref_dict=ref_dict),
            down_block_additional_residuals = [
                sample.to(dtype=dtype) for sample in down_block_res_samples
            ] if down_block_res_samples is not None else None,
            mid_block_additional_residual = (
                mid_block_res_sample.to(dtype=dtype)
                if mid_block_res_sample is not None else None),
            added_cond_kwargs = added_cond_kwargs,
            return_dict = return_dict,
            **kwargs
        )
        return res


class HunYuan3D_MVD_Std_Pipeline(diffusers.DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        feature_extractor_vae: CLIPImageProcessor,
        vision_processor: CLIPImageProcessor,
        vision_encoder: CLIPVisionModelWithProjection,
        vision_encoder_2: CLIPVisionModelWithProjection,
        ramping_coefficients: Optional[list] = None,
        add_watermarker: Optional[bool] = None,
        safety_checker = None,
    ):
        DiffusionPipeline.__init__(self)

        self.register_modules(
            vae=vae, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor_vae=feature_extractor_vae,
            vision_processor=vision_processor, vision_encoder=vision_encoder, vision_encoder_2=vision_encoder_2,
        )
        self.register_to_config( ramping_coefficients = ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = self.unet.config.sample_size
        self.watermark = None
        self.prepare_init = False

    def prepare(self):
        assert isinstance(self.unet, UNet2DConditionModel), "unet should be UNet2DConditionModel"
        self.unet = RefOnlyNoisedUNet(self.unet, self.scheduler).eval()
        self.prepare_init = True

    def encode_image(self, image: torch.Tensor, scale_factor: bool = False):
        latent = self.vae.encode(image).latent_dist.sample()
        return (latent * self.vae.config.scaling_factor) if scale_factor else latent

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def _get_add_time_ids(
        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
    ):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)

        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \
                f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config." \
                f" Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta: extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator: extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def interrupt(self):
        return self._interrupt

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        guidance_scale = 2.0,
        output_type: Optional[str] = "pil",
        num_inference_steps: int = 50,
        return_dict: bool = True,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        latent: torch.Tensor = None,
        guidance_curve = None,
        **kwargs
    ):
        if not self.prepare_init:
            self.prepare()

        here = dict(device=self.vae.device, dtype=self.vae.dtype)

        batch_size = 1
        num_images_per_prompt = 1
        width, height = 512 * 2, 512 * 3
        target_size = original_size = (height, width)

        self._guidance_scale = guidance_scale
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        device = self._execution_device

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            self.vae.dtype,
            device,
            generator,
            latents=latent,
        )

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Prepare added time ids & embeddings
        text_encoder_projection_dim = 1280
        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=self.vae.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        negative_add_time_ids = add_time_ids

        # hw: preprocess
        cond_image = recenter_img(image)
        cond_image = to_rgb_image(image)
        image_vae = self.feature_extractor_vae(images=cond_image, return_tensors="pt").pixel_values.to(**here)
        image_clip = self.vision_processor(images=cond_image, return_tensors="pt").pixel_values.to(**here)

        # hw: get cond_lat from cond_img using vae
        cond_lat = self.encode_image(image_vae, scale_factor=False)
        negative_lat = self.encode_image(torch.zeros_like(image_vae), scale_factor=False)
        cond_lat = torch.cat([negative_lat, cond_lat])

        # hw: get visual global embedding using clip
        global_embeds_1 = self.vision_encoder(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
        global_embeds_2 = self.vision_encoder_2(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
        global_embeds = torch.concat([global_embeds_1, global_embeds_2], dim=-1)

        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        prompt_embeds = self.uc_text_emb.to(**here)
        pooled_prompt_embeds = self.uc_text_emb_2.to(**here)

        prompt_embeds = prompt_embeds + global_embeds * ramp
        add_text_embeds = pooled_prompt_embeds

        if self.do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        timestep_cond = None
        self._num_timesteps = len(timesteps)

        if guidance_curve is None:
            guidance_curve = lambda t: guidance_scale

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=dict(cond_lat=cond_lat),
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance

                # cur_guidance_scale = self.guidance_scale
                cur_guidance_scale = guidance_curve(t)  # 1.5 + 2.5 * ((t/1000)**2)

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + cur_guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # cur_guidance_scale_topleft = (cur_guidance_scale - 1.0) * 4 + 1.0
                    # noise_pred_top_left = noise_pred_uncond +
                    #     cur_guidance_scale_topleft * (noise_pred_text - noise_pred_uncond)
                    # _, _, h, w = noise_pred.shape
                    # noise_pred[:, :, :h//3, :w//2] = noise_pred_top_left[:, :, :h//3, :w//2]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        latents = unscale_latents(latents)

        if output_type=="latent":
            image = latents
        else:
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image = unscale_image(unscale_image_2(image)).clamp(0, 1)
            image = [
                Image.fromarray((image[0]*255+0.5).clamp_(0, 255).permute(1, 2, 0).cpu().numpy().astype("uint8")),
                # self.image_processor.postprocess(image, output_type=output_type)[0],
                cond_image.resize((512, 512))
            ]

        if not return_dict: return (image,)
        return ImagePipelineOutput(images=image)

    def save_pretrained(self, save_directory):
        # uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance
        super().save_pretrained(save_directory)
        torch.save(self.uc_text_emb, os.path.join(save_directory, "uc_text_emb.pt"))
        torch.save(self.uc_text_emb_2, os.path.join(save_directory, "uc_text_emb_2.pt"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance
        pipeline = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        pipeline.uc_text_emb = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb.pt"))
        pipeline.uc_text_emb_2 = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb_2.pt"))
        return pipeline
mvd/utils.py
DELETED
@@ -1,85 +0,0 @@
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import numpy as np
from PIL import Image

def to_rgb_image(maybe_rgba: Image.Image):
    '''
    convert a PIL.Image to rgb mode with white background
    maybe_rgba: PIL.Image
    return: PIL.Image
    '''
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    elif maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        img = np.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)

def white_out_background(pil_img, is_gray_fg=True):
    data = pil_img.getdata()
    new_data = []
    # convert fore-ground white to gray
    for r, g, b, a in data:
        if a < 16:
            new_data.append((255, 255, 255, 0))  # back-ground to be black
        else:
            is_white = is_gray_fg and (r>235) and (g>235) and (b>235)
            new_r = 235 if is_white else r
            new_g = 235 if is_white else g
            new_b = 235 if is_white else b
            new_data.append((new_r, new_g, new_b, a))
    pil_img.putdata(new_data)
    return pil_img

def recenter_img(img, size=512, color=(255,255,255)):
    img = white_out_background(img)
    mask = np.array(img)[..., 3]
    image = np.array(img)[..., :3]

    H, W, C = image.shape
    coords = np.nonzero(mask)
    x_min, x_max = coords[0].min(), coords[0].max()
    y_min, y_max = coords[1].min(), coords[1].max()
    h = x_max - x_min
    w = y_max - y_min
    if h == 0 or w == 0: raise ValueError
    roi = image[x_min:x_max, y_min:y_max]

    border_ratio = 0.15  # 0.2
    pad_h = int(h * border_ratio)
    pad_w = int(w * border_ratio)

    result_tmp = np.full((h + pad_h, w + pad_w, C), color, dtype=np.uint8)
    result_tmp[pad_h // 2: pad_h // 2 + h, pad_w // 2: pad_w // 2 + w] = roi

    cur_h, cur_w = result_tmp.shape[:2]
    side = max(cur_h, cur_w)
    result = np.full((side, side, C), color, dtype=np.uint8)
    result[(side-cur_h)//2:(side-cur_h)//2+cur_h, (side-cur_w)//2:(side - cur_w)//2+cur_w,:] = result_tmp
    result = Image.fromarray(result)
    return result.resize((size, size), Image.LANCZOS) if size else result
requirements.txt
DELETED
@@ -1,22 +0,0 @@
--find-links https://download.pytorch.org/whl/cu118
torch==2.2.0
torchvision==0.17.0
diffusers
transformers
rembg
tqdm
omegaconf
matplotlib
opencv-python
imageio
jaxtyping
einops
SentencePiece
accelerate
trimesh
PyMCubes
xatlas
libigl
git+https://github.com/facebookresearch/pytorch3d@stable
git+https://github.com/NVlabs/nvdiffrast
open3d
scripts/image_to_3d.sh
DELETED
@@ -1,8 +0,0 @@
# image to 3d

python main.py \
    --image_prompt ./demos/example_000.png \
    --save_folder ./outputs/test/ \
    --max_faces_num 90000 \
    --do_texture \
    --do_render
scripts/image_to_3d_demo.sh
DELETED
@@ -1,8 +0,0 @@
# image to 3d

python main.py \
    --image_prompt ./demos/example_000.png \
    --save_folder ./outputs/test/ \
    --max_faces_num 90000 \
    --do_texture_mapping \
    --do_render
scripts/image_to_3d_fast.sh
DELETED
@@ -1,6 +0,0 @@
# image to 3d fast
python main.py \
    --image_prompt ./demos/example_000.png \
    --save_folder ./outputs/test/ \
    --max_faces_num 10000 \
    --use_lite
scripts/image_to_3d_fast_demo.sh
DELETED
@@ -1,6 +0,0 @@
# image to 3d fast
python main.py \
    --image_prompt ./demos/example_000.png \
    --save_folder ./outputs/test/ \
    --max_faces_num 10000 \
    --use_lite
scripts/text_to_3d.sh
DELETED
@@ -1,7 +0,0 @@
# text to 3d fast
python main.py \
    --text_prompt "a lovely cat" \
    --save_folder ./outputs/test/ \
    --max_faces_num 90000 \
    --do_texture \
    --do_render
scripts/text_to_3d_demo.sh
DELETED
@@ -1,7 +0,0 @@
# text to 3d fast
python main.py \
    --text_prompt "a lovely rabbit" \
    --save_folder ./outputs/test/ \
    --max_faces_num 90000 \
    --do_texture_mapping \
    --do_render
scripts/text_to_3d_fast.sh
DELETED
@@ -1,6 +0,0 @@
# text to 3d fast
python main.py \
    --text_prompt "一个广式茶杯" \
    --save_folder ./outputs/test/ \
    --max_faces_num 10000 \
    --use_lite
scripts/text_to_3d_fast_demo.sh
DELETED
@@ -1,6 +0,0 @@
# text to 3d fast
python main.py \
    --text_prompt "一个广式茶杯" \
    --save_folder ./outputs/test/ \
    --max_faces_num 10000 \
    --use_lite
svrm/.DS_Store
DELETED
Binary file (6.15 kB)
svrm/configs/2024-10-24T22-36-18-project.yaml
DELETED
@@ -1,32 +0,0 @@
model:
  base_learning_rate: 3.0e-05
  target: svrm.ldm.models.svrm.SVRMModel
  params:

    img_encoder_config:
      target: svrm.ldm.modules.encoders.dinov2_mod.FrozenDinoV2ImageEmbedder
      params:
        version: dinov2_vitb14

    img_to_triplane_config:
      target: svrm.ldm.modules.translator.img_to_triplane.ImgToTriplaneModel
      params:
        pos_emb_size: 64
        pos_emb_dim: 1024
        cam_cond_dim: 20
        n_heads: 16
        d_head: 64
        depth: 16
        context_dim: 768
        triplane_dim: 120
        use_fp16: true
        use_bf16: false
        upsample_time: 2

    render_config:
      target: svrm.ldm.modules.rendering_neus.synthesizer.TriplaneSynthesizer
      params:
        triplane_dim: 120
        samples_per_ray: 128
svrm/configs/svrm.yaml
DELETED
@@ -1,32 +0,0 @@
model:
  base_learning_rate: 3.0e-05
  target: svrm.ldm.models.svrm.SVRMModel
  params:

    img_encoder_config:
      target: svrm.ldm.modules.encoders.dinov2_mod.FrozenDinoV2ImageEmbedder
      params:
        version: dinov2_vitb14

    img_to_triplane_config:
      target: svrm.ldm.modules.translator.img_to_triplane.ImgToTriplaneModel
      params:
        pos_emb_size: 64
        pos_emb_dim: 1024
        cam_cond_dim: 20
        n_heads: 16
        d_head: 64
        depth: 16
        context_dim: 768
        triplane_dim: 120
        use_fp16: true
        use_bf16: false
        upsample_time: 2

    render_config:
      target: svrm.ldm.modules.rendering_neus.synthesizer.TriplaneSynthesizer
      params:
        triplane_dim: 120
        samples_per_ray: 128
svrm/ldm/.DS_Store
DELETED
Binary file (6.15 kB)
svrm/ldm/models/svrm.py
DELETED
@@ -1,263 +0,0 @@
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os
import time
import math
import cv2
import numpy as np
import itertools
import shutil
from tqdm import tqdm
import torch
import torch.nn.functional as F
from einops import rearrange
try:
    import trimesh
    import mcubes
    import xatlas
    import open3d as o3d
except:
    raise "failed to import 3d libraries "

from ..modules.rendering_neus.mesh import Mesh
from ..modules.rendering_neus.rasterize import NVDiffRasterizerContext

from ..utils.ops import scale_tensor
from ..util import count_params, instantiate_from_config
from ..vis_util import render


def unwrap_uv(v_pos, t_pos_idx):
    print("Using xatlas to perform UV unwrapping, may take a while ...")
    atlas = xatlas.Atlas()
    atlas.add_mesh(v_pos, t_pos_idx)
    atlas.generate(xatlas.ChartOptions(), xatlas.PackOptions())
    _, indices, uvs = atlas.get_mesh(0)
    indices = indices.astype(np.int64, casting="same_kind")
    return uvs, indices


def uv_padding(image, hole_mask, uv_padding_size = 2):
    return cv2.inpaint(
        (image.detach().cpu().numpy() * 255).astype(np.uint8),
        (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8),
        uv_padding_size,
        cv2.INPAINT_TELEA
    )

def refine_mesh(vtx_refine, faces_refine):
    mesh = o3d.geometry.TriangleMesh(
        vertices=o3d.utility.Vector3dVector(vtx_refine),
        triangles=o3d.utility.Vector3iVector(faces_refine))

    mesh = mesh.remove_unreferenced_vertices()
    mesh = mesh.remove_duplicated_triangles()
    mesh = mesh.remove_duplicated_vertices()

    voxel_size = max(mesh.get_max_bound() - mesh.get_min_bound())

    mesh = mesh.simplify_vertex_clustering(
        voxel_size=0.007,  # 0.005
        contraction=o3d.geometry.SimplificationContraction.Average)

    mesh = mesh.filter_smooth_simple(number_of_iterations=2)

    vtx_refine = np.asarray(mesh.vertices).astype(np.float32)
    faces_refine = np.asarray(mesh.triangles)
    return vtx_refine, faces_refine, mesh


class SVRMModel(torch.nn.Module):
    def __init__(
        self,
        img_encoder_config,
        img_to_triplane_config,
        render_config,
        device = "cuda:0",
        **kwargs
    ):
        super().__init__()

        self.img_encoder = instantiate_from_config(img_encoder_config).half()
        self.img_to_triplane_decoder = instantiate_from_config(img_to_triplane_config).half()
        self.render = instantiate_from_config(render_config).half()
        self.device = device
        count_params(self, verbose=True)

    @torch.no_grad()
    def export_mesh_with_uv(
        self,
        data,
        mesh_size: int = 384,
        ctx = None,
        context_type = 'cuda',
        texture_res = 1024,
        target_face_count = 10000,
        do_texture_mapping = True,
        out_dir = 'outputs/test'
    ):
        """
        color_type: 0 for ray texture, 1 for vertices texture
        """
        st = time.time()
        here = {'device': self.device, 'dtype': torch.float16}
        input_view_image = data["input_view"].to(**here)    # [b, m, c, h, w]
        input_view_cam = data["input_view_cam"].to(**here)   # [b, m, 20]

        batch_size, input_view_num, *_ = input_view_image.shape
        assert batch_size == 1, "batch size should be 1"

        input_view_image = rearrange(input_view_image, 'b m c h w -> (b m) c h w')
        input_view_cam = rearrange(input_view_cam, 'b m d -> (b m) d')
        input_view_feat = self.img_encoder(input_view_image, input_view_cam)
        input_view_feat = rearrange(input_view_feat, '(b m) l d -> b (l m) d', m=input_view_num)

        # -- decoder
        torch.cuda.empty_cache()
        triplane_gen = self.img_to_triplane_decoder(input_view_feat)  # [b, 3, tri_dim, h, w]
        del input_view_feat
        torch.cuda.empty_cache()

        # --- triplane nerf render

        cur_triplane = triplane_gen[0:1]

        aabb = torch.tensor([[-0.6, -0.6, -0.6], [0.6, 0.6, 0.6]]).unsqueeze(0).to(**here)
        grid_out = self.render.forward_grid(planes=cur_triplane, grid_size=mesh_size, aabb=aabb)

        print(f"=====> LRM forward time: {time.time() - st}")
        st = time.time()

        vtx, faces = mcubes.marching_cubes(0. - grid_out['sdf'].squeeze(0).squeeze(-1).cpu().float().numpy(), 0)

        bbox = aabb[0].cpu().numpy()
        vtx = vtx / (mesh_size - 1)
        vtx = vtx * (bbox[1] - bbox[0]) + bbox[0]

        # refine mesh
        vtx_refine, faces_refine, mesh = refine_mesh(vtx, faces)

        # reduce faces
        if faces_refine.shape[0] > target_face_count:
            print(f"reduce face: {faces_refine.shape[0]} -> {target_face_count}")
            mesh = o3d.geometry.TriangleMesh(
                vertices = o3d.utility.Vector3dVector(vtx_refine),
                triangles = o3d.utility.Vector3iVector(faces_refine)
            )

            # Function to simplify mesh using Quadric Error Metric Decimation by Garland and Heckbert
            mesh = mesh.simplify_quadric_decimation(target_face_count, boundary_weight=1.0)

            mesh = Mesh(
                v_pos = torch.from_numpy(np.asarray(mesh.vertices)).to(self.device),
                t_pos_idx = torch.from_numpy(np.asarray(mesh.triangles)).to(self.device),
                v_rgb = torch.from_numpy(np.asarray(mesh.vertex_colors)).to(self.device)
            )
            vtx_refine = mesh.v_pos.cpu().numpy()
            faces_refine = mesh.t_pos_idx.cpu().numpy()

        vtx_colors = self.render.forward_points(cur_triplane, torch.tensor(vtx_refine).unsqueeze(0).to(**here))
        vtx_colors = vtx_colors['rgb'].float().squeeze(0).cpu().numpy()

        color_ratio = 0.8  # increase brightness
        with open(f'{out_dir}/mesh_with_colors.obj', 'w') as fid:
            verts = vtx_refine[:, [1,2,0]]
            for pidx, pp in enumerate(verts):
                color = vtx_colors[pidx]
                color = [color[0]**color_ratio, color[1]**color_ratio, color[2]**color_ratio]
                fid.write('v %f %f %f %f %f %f\n' % (pp[0], pp[1], pp[2], color[0], color[1], color[2]))
            for i, f in enumerate(faces_refine):
                f1 = f + 1
                fid.write('f %d %d %d\n' % (f1[0], f1[1], f1[2]))

        mesh = trimesh.load_mesh(f'{out_dir}/mesh_with_colors.obj')
        print(f"=====> generate mesh with vertex shading time: {time.time() - st}")
        st = time.time()

        if not do_texture_mapping:
            shutil.copy(f'{out_dir}/mesh_with_colors.obj', f'{out_dir}/mesh.obj')
            mesh.export(f'{out_dir}/mesh.glb', file_type='glb')
            return None

        ########## export texture ########
        st = time.time()

        # uv unwrap
        vtx_tex, t_tex_idx = unwrap_uv(vtx_refine, faces_refine)
        vtx_refine = torch.from_numpy(vtx_refine).to(self.device)
        faces_refine = torch.from_numpy(faces_refine).to(self.device)
        t_tex_idx = torch.from_numpy(t_tex_idx).to(self.device)
        uv_clip = torch.from_numpy(vtx_tex * 2.0 - 1.0).to(self.device)

        # rasterize
        ctx = NVDiffRasterizerContext(context_type, cur_triplane.device) if ctx is None else ctx
        rast = ctx.rasterize_one(
            torch.cat([
                uv_clip,
                torch.zeros_like(uv_clip[..., 0:1]),
                torch.ones_like(uv_clip[..., 0:1])
            ], dim=-1),
            t_tex_idx,
            (texture_res, texture_res)
        )[0]
        hole_mask = ~(rast[:, :, 3] > 0)

        # Interpolate world space position
        gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0]
        with torch.no_grad():
            gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
            tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
            tex_map = tex_map.float().squeeze(0)  # (0, 1)
            tex_map = tex_map.view((texture_res, texture_res, 3))
            img = uv_padding(tex_map, hole_mask)
            img = ((img/255.0) ** color_ratio) * 255  # increase brightness
            img = img.clip(0, 255).astype(np.uint8)

        verts = vtx_refine.cpu().numpy()[:, [1,2,0]]
        faces = faces_refine.cpu().numpy()

        with open(f'{out_dir}/texture.mtl', 'w') as fid:
            fid.write('newmtl material_0\n')
            fid.write("Ka 1.000 1.000 1.000\n")
            fid.write("Kd 1.000 1.000 1.000\n")
            fid.write("Ks 0.000 0.000 0.000\n")
            fid.write("d 1.0\n")
            fid.write("illum 2\n")
            fid.write(f'map_Kd texture.png\n')

        with open(f'{out_dir}/mesh.obj', 'w') as fid:
            fid.write(f'mtllib texture.mtl\n')
            for pidx, pp in enumerate(verts):
                fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2]))
            for pidx, pp in enumerate(vtx_tex):
                fid.write('vt %f %f\n' % (pp[0], 1 - pp[1]))
            fid.write('usemtl material_0\n')
            for i, f in enumerate(faces):
                f1 = f + 1
                f2 = t_tex_idx[i] + 1
                fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2],))

        cv2.imwrite(f'{out_dir}/texture.png', img[..., [2, 1, 0]])
        mesh = trimesh.load_mesh(f'{out_dir}/mesh.obj')
        mesh.export(f'{out_dir}/mesh.glb', file_type='glb')
svrm/ldm/modules/attention.py
DELETED
@@ -1,457 +0,0 @@
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
import numpy as np

FLASH_IS_AVAILABLE = XFORMERS_IS_AVAILBLE = False
try:
    from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
    FLASH_IS_AVAILABLE = True
except:
    try:
        import xformers
        import xformers.ops
        XFORMERS_IS_AVAILBLE = True
    except:
        pass

def exists(val):
    return val is not None


def uniq(arr):
    return{el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor

def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)

        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        return x+h_


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
220 |
-
|
221 |
-
self.to_out = nn.Sequential(
|
222 |
-
nn.Linear(inner_dim, query_dim),
|
223 |
-
nn.Dropout(dropout)
|
224 |
-
)
|
225 |
-
|
226 |
-
def forward(self, x, context=None, mask=None):
|
227 |
-
h = self.heads
|
228 |
-
q = self.to_q(x)
|
229 |
-
context = default(context, x)
|
230 |
-
k = self.to_k(context)
|
231 |
-
v = self.to_v(context)
|
232 |
-
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
233 |
-
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
234 |
-
if exists(mask):
|
235 |
-
mask = rearrange(mask, 'b ... -> b (...)')
|
236 |
-
max_neg_value = -torch.finfo(sim.dtype).max
|
237 |
-
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
238 |
-
sim.masked_fill_(~mask, max_neg_value)
|
239 |
-
# attention, what we cannot get enough of
|
240 |
-
attn = sim.softmax(dim=-1)
|
241 |
-
out = einsum('b i j, b j d -> b i d', attn, v) # [b*h, n, d]
|
242 |
-
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
243 |
-
return self.to_out(out)
|
244 |
-
|
245 |
-
|
246 |
-
class FlashAttention(nn.Module):
|
247 |
-
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
248 |
-
super().__init__()
|
249 |
-
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
250 |
-
f"{heads} heads.")
|
251 |
-
inner_dim = dim_head * heads
|
252 |
-
context_dim = default(context_dim, query_dim)
|
253 |
-
self.scale = dim_head ** -0.5
|
254 |
-
self.heads = heads
|
255 |
-
self.dropout = dropout
|
256 |
-
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
257 |
-
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
258 |
-
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
259 |
-
self.to_out = nn.Sequential(
|
260 |
-
nn.Linear(inner_dim, query_dim),
|
261 |
-
nn.Dropout(dropout)
|
262 |
-
)
|
263 |
-
|
264 |
-
def forward(self, x, context=None, mask=None):
|
265 |
-
context = default(context, x)
|
266 |
-
h = self.heads
|
267 |
-
dtype = torch.bfloat16 # torch.half
|
268 |
-
q = self.to_q(x).to(dtype)
|
269 |
-
k = self.to_k(context).to(dtype)
|
270 |
-
v = self.to_v(context).to(dtype)
|
271 |
-
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64]
|
272 |
-
out = flash_attn_func(q, k, v, dropout_p=self.dropout, softmax_scale=None, causal=False, window_size=(-1, -1)) # out is same shape to q
|
273 |
-
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
274 |
-
return self.to_out(out.float())
|
275 |
-
|
276 |
-
class MemoryEfficientCrossAttention(nn.Module):
|
277 |
-
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
278 |
-
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
279 |
-
super().__init__()
|
280 |
-
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
281 |
-
f"{heads} heads.")
|
282 |
-
inner_dim = dim_head * heads
|
283 |
-
context_dim = default(context_dim, query_dim)
|
284 |
-
|
285 |
-
self.heads = heads
|
286 |
-
self.dim_head = dim_head
|
287 |
-
|
288 |
-
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
289 |
-
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
290 |
-
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
291 |
-
|
292 |
-
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
293 |
-
self.attention_op: Optional[Any] = None
|
294 |
-
|
295 |
-
def forward(self, x, context=None, mask=None):
|
296 |
-
q = self.to_q(x)
|
297 |
-
context = default(context, x)
|
298 |
-
k = self.to_k(context)
|
299 |
-
v = self.to_v(context)
|
300 |
-
|
301 |
-
b, _, _ = q.shape
|
302 |
-
q, k, v = map(
|
303 |
-
lambda t: t.unsqueeze(3)
|
304 |
-
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
305 |
-
.permute(0, 2, 1, 3)
|
306 |
-
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
307 |
-
.contiguous(),
|
308 |
-
(q, k, v),
|
309 |
-
)
|
310 |
-
|
311 |
-
# actually compute the attention, what we cannot get enough of
|
312 |
-
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
313 |
-
|
314 |
-
if exists(mask):
|
315 |
-
raise NotImplementedError
|
316 |
-
out = (
|
317 |
-
out.unsqueeze(0)
|
318 |
-
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
319 |
-
.permute(0, 2, 1, 3)
|
320 |
-
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
321 |
-
)
|
322 |
-
return self.to_out(out)
|
323 |
-
|
324 |
-
class BasicTransformerBlock(nn.Module):
|
325 |
-
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
326 |
-
disable_self_attn=False):
|
327 |
-
super().__init__()
|
328 |
-
self.disable_self_attn = disable_self_attn
|
329 |
-
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
330 |
-
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
331 |
-
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
332 |
-
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
333 |
-
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
334 |
-
self.norm1 = Fp32LayerNorm(dim)
|
335 |
-
self.norm2 = Fp32LayerNorm(dim)
|
336 |
-
self.norm3 = Fp32LayerNorm(dim)
|
337 |
-
self.checkpoint = checkpoint
|
338 |
-
|
339 |
-
def forward(self, x, context=None):
|
340 |
-
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
341 |
-
|
342 |
-
def _forward(self, x, context=None):
|
343 |
-
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
344 |
-
x = self.attn2(self.norm2(x), context=context) + x
|
345 |
-
x = self.ff(self.norm3(x)) + x
|
346 |
-
return x
|
347 |
-
|
348 |
-
ATTENTION_MODES = {
|
349 |
-
"softmax": CrossAttention, # vanilla attention
|
350 |
-
"softmax-xformers": MemoryEfficientCrossAttention,
|
351 |
-
"softmax-flash": FlashAttention
|
352 |
-
}
|
353 |
-
|
354 |
-
def modulate(x, shift, scale):
|
355 |
-
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
356 |
-
|
357 |
-
|
358 |
-
class Fp32LayerNorm(nn.LayerNorm):
|
359 |
-
def __init__(self, *args, **kwargs):
|
360 |
-
super().__init__(*args, **kwargs)
|
361 |
-
def forward(self, x):
|
362 |
-
return super().forward(x.float()).type(x.dtype)
|
363 |
-
|
364 |
-
|
365 |
-
class AdaNorm(nn.Module):
|
366 |
-
def __init__(self, dim):
|
367 |
-
super().__init__()
|
368 |
-
self.adaLN_modulation = nn.Sequential(
|
369 |
-
nn.SiLU(),
|
370 |
-
nn.Linear(dim, 2 * dim, bias=True)
|
371 |
-
)
|
372 |
-
self.norm = Fp32LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
373 |
-
|
374 |
-
def forward(self, x, c): # x is fp32, c is fp16
|
375 |
-
shift, scale = self.adaLN_modulation(c.float()).chunk(2, dim=1) # bf16
|
376 |
-
x = modulate(self.norm(x), shift, scale) # fp32
|
377 |
-
return x
|
378 |
-
|
379 |
-
|
380 |
-
class BasicTransformerBlockLRM(nn.Module):
|
381 |
-
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, \
|
382 |
-
checkpoint=True):
|
383 |
-
super().__init__()
|
384 |
-
|
385 |
-
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
386 |
-
attn_mode = "softmax-flash" if FLASH_IS_AVAILABLE else attn_mode
|
387 |
-
assert attn_mode in ATTENTION_MODES
|
388 |
-
attn_cls = ATTENTION_MODES[attn_mode]
|
389 |
-
|
390 |
-
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, \
|
391 |
-
context_dim=context_dim) # cross-attn
|
392 |
-
self.attn2 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, \
|
393 |
-
context_dim=None) # self-attn
|
394 |
-
|
395 |
-
self.norm1 = Fp32LayerNorm(dim)
|
396 |
-
self.norm2 = Fp32LayerNorm(dim)
|
397 |
-
self.norm3 = Fp32LayerNorm(dim)
|
398 |
-
|
399 |
-
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
400 |
-
self.checkpoint = checkpoint
|
401 |
-
|
402 |
-
def forward(self, x, context=None, cam_emb=None): # (torch.float32, torch.float32, torch.bfloat16)
|
403 |
-
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
404 |
-
|
405 |
-
|
406 |
-
def _forward(self, x, context=None, cam_emb=None):
|
407 |
-
|
408 |
-
x = self.attn1(self.norm1(x), context=context) + x # cross-attn
|
409 |
-
x = self.attn2(self.norm2(x), context=None) + x # self-attn
|
410 |
-
x = self.ff(self.norm3(x)) + x
|
411 |
-
|
412 |
-
return x
|
413 |
-
|
414 |
-
class ImgToTriplaneTransformer(nn.Module):
|
415 |
-
"""
|
416 |
-
Transformer block for image-like data.
|
417 |
-
First, project the input (aka embedding)
|
418 |
-
and reshape to b, t, d.
|
419 |
-
Then apply standard transformer action.
|
420 |
-
Finally, reshape to image
|
421 |
-
"""
|
422 |
-
def __init__(self, query_dim, n_heads, d_head, depth=1, dropout=0., context_dim=None, triplane_size=64):
|
423 |
-
super().__init__()
|
424 |
-
|
425 |
-
self.transformer_blocks = nn.ModuleList([
|
426 |
-
BasicTransformerBlockLRM(query_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
427 |
-
for d in range(depth)])
|
428 |
-
|
429 |
-
self.norm = Fp32LayerNorm(query_dim, eps=1e-6)
|
430 |
-
|
431 |
-
self.initialize_weights()
|
432 |
-
|
433 |
-
def initialize_weights(self):
|
434 |
-
# Initialize transformer layers:
|
435 |
-
def _basic_init(module):
|
436 |
-
if isinstance(module, nn.Linear):
|
437 |
-
torch.nn.init.xavier_uniform_(module.weight)
|
438 |
-
if module.bias is not None:
|
439 |
-
nn.init.constant_(module.bias, 0)
|
440 |
-
elif isinstance(module, nn.LayerNorm):
|
441 |
-
if module.bias is not None:
|
442 |
-
nn.init.constant_(module.bias, 0)
|
443 |
-
if module.weight is not None:
|
444 |
-
nn.init.constant_(module.weight, 1.0)
|
445 |
-
self.apply(_basic_init)
|
446 |
-
|
447 |
-
def forward(self, x, context=None, cam_emb=None):
|
448 |
-
# note: if no context is given, cross-attention defaults to self-attention
|
449 |
-
for block in self.transformer_blocks:
|
450 |
-
x = block(x, context=context)
|
451 |
-
x = self.norm(x)
|
452 |
-
return x
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
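
For reference, a minimal sketch of how the deleted image-to-triplane transformer above could be exercised. The import path mirrors this repository's layout, and every dimension below (query_dim, context_dim, token counts) is an illustrative assumption rather than a value taken from the model config; with neither flash-attn nor xformers installed it exercises the plain softmax attention path.

    # Minimal usage sketch (assumed import path and toy dimensions).
    import torch
    from svrm.ldm.modules.attention import ImgToTriplaneTransformer

    model = ImgToTriplaneTransformer(
        query_dim=512,     # width of the triplane tokens (illustrative)
        n_heads=8,
        d_head=64,
        depth=2,
        context_dim=768,   # width of the image tokens used as cross-attention context (illustrative)
    )
    triplane_tokens = torch.randn(1, 3 * 32 * 32, 512)   # flattened toy triplane
    image_tokens = torch.randn(1, 257, 768)              # toy image features
    out = model(triplane_tokens, context=image_tokens)   # same shape as triplane_tokens
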
svrm/ldm/modules/encoders/__init__.py DELETED (file without changes)
svrm/ldm/modules/encoders/dinov2/__init__.py DELETED (file without changes)
svrm/ldm/modules/encoders/dinov2/hub/__init__.py DELETED (file without changes)
svrm/ldm/modules/encoders/dinov2/hub/backbones.py DELETED
@@ -1,156 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
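
A small sketch of how these hub entry points could be used to build a backbone without triggering a checkpoint download; it assumes the vendored vision_transformer module (referenced above but not shown in this diff) is still importable alongside these files, and the import path is taken from this repository's layout.

    # Build a DINOv2 ViT-B/14 without fetching the LVD-142M weights (assumed import path).
    from svrm.ldm.modules.encoders.dinov2.hub.backbones import dinov2_vitb14

    backbone = dinov2_vitb14(pretrained=False)   # pretrained=False skips load_state_dict_from_url
    n_params = sum(p.numel() for p in backbone.parameters())
    print(f"ViT-B/14 parameters: {n_params / 1e6:.1f}M")
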
svrm/ldm/modules/encoders/dinov2/hub/utils.py DELETED
@@ -1,39 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import itertools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    compact_arch_name = arch_name.replace("_", "")[:4]
    registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
    return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"


class CenterPadding(nn.Module):
    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        new_size = math.ceil(size / self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    @torch.inference_mode()
    def forward(self, x):
        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
        output = F.pad(x, pads)
        return output
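
The CenterPadding helper above is self-contained, so a short sketch of its behaviour is easy to give: it symmetrically pads the trailing spatial dimensions up to a multiple of the patch size. The import path mirrors this repository's layout.

    # Pad arbitrary image sizes up to a multiple of the 14-pixel DINOv2 patch size.
    import torch
    from svrm.ldm.modules.encoders.dinov2.hub.utils import CenterPadding, _make_dinov2_model_name

    pad = CenterPadding(multiple=14)
    x = torch.randn(2, 3, 250, 333)
    y = pad(x)                                               # -> torch.Size([2, 3, 252, 336])
    print(y.shape, _make_dinov2_model_name("vit_base", 14))  # "dinov2_vitb14"
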
svrm/ldm/modules/encoders/dinov2/layers/__init__.py DELETED
@@ -1,11 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlockMod
from .attention import MemEffAttention
svrm/ldm/modules/encoders/dinov2/layers/attention.py DELETED
@@ -1,89 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
    else:
        warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Attention)")


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
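
As a quick sanity check of the plain (non-xFormers) path above, a toy forward pass; the import path follows this repository's layout and the dimensions are illustrative assumptions.

    # Toy forward pass through the vanilla Attention layer (assumed import path).
    import torch
    from svrm.ldm.modules.encoders.dinov2.layers.attention import Attention

    attn = Attention(dim=384, num_heads=6)   # ViT-S-like width, 64-dim heads
    tokens = torch.randn(2, 257, 384)        # e.g. CLS token + 16x16 patch tokens
    out = attn(tokens)                       # shape preserved: (2, 257, 384)
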
svrm/ldm/modules/encoders/dinov2/layers/block.py DELETED
@@ -1,269 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import os
import logging
import warnings
from typing import Callable, List, Any, Tuple, Dict

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp

from ....attention import AdaNorm


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Block)")
    else:
        warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Block)")


class BlockMod(nn.Module):
    '''
    using Modified Block, see below
    '''
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = AdaNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, cam_emb: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor, cam_emb: Tensor = None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x, cam_emb)))

        def ffn_residual_func(x: Tensor, cam_emb: Tensor = None) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x, cam_emb)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, cam_emb))
            x = x + self.drop_path1(ffn_residual_func(x, cam_emb))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x, cam_emb)
            x = x + ffn_residual_func(x, cam_emb)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # drop_add_residual_stochastic_depth_list

    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    # get_branges_scales
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    # add residuals
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlockMod(BlockMod):
    def forward_nested(self, x_list: List[Tensor], cam_emb_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x, cam_emb), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x, cam_emb))

            x_list = drop_add_residual_stochastic_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x, cam_emb), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x, cam_emb)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list, cam_emb_or_cam_emb_list):
        if isinstance(x_or_x_list, Tensor) and isinstance(cam_emb_or_cam_emb_list, Tensor):
            return super().forward(x_or_x_list, cam_emb_or_cam_emb_list)
        elif isinstance(x_or_x_list, list) and isinstance(cam_emb_or_cam_emb_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list, cam_emb_or_cam_emb_list)
        else:
            raise AssertionError
svrm/ldm/modules/encoders/dinov2/layers/dino_head.py DELETED
@@ -1,58 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)
svrm/ldm/modules/encoders/dinov2/layers/drop_path.py DELETED
@@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
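
A tiny sketch of the stochastic-depth behaviour implemented above: in training mode each sample's residual branch is zeroed with probability drop_prob and the survivors are rescaled by 1/(1 - drop_prob); in eval mode the layer is an identity. The import path is assumed from this repository's layout.

    import torch
    from svrm.ldm.modules.encoders.dinov2.layers.drop_path import DropPath

    dp = DropPath(drop_prob=0.5)
    x = torch.ones(8, 4, 16)       # (batch, tokens, dim)
    dp.train()
    y = dp(x)                      # per sample: either all zeros or all 2.0
    dp.eval()
    assert torch.equal(dp(x), x)   # identity at inference time
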
svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py DELETED
@@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
svrm/ldm/modules/encoders/dinov2/layers/mlp.py DELETED
@@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x