mac9087 commited on
Commit
84ed04d
·
verified ·
1 Parent(s): 1c064d0

Create bake_texture.py

Browse files
Files changed (1) hide show
  1. tsr/bake_texture.py +170 -0
tsr/bake_texture.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import xatlas
4
+ import trimesh
5
+ import moderngl
6
+ from PIL import Image
7
+
8
+
9
+ def make_atlas(mesh, texture_resolution, texture_padding):
10
+ atlas = xatlas.Atlas()
11
+ atlas.add_mesh(mesh.vertices, mesh.faces)
12
+ options = xatlas.PackOptions()
13
+ options.resolution = texture_resolution
14
+ options.padding = texture_padding
15
+ options.bilinear = True
16
+ atlas.generate(pack_options=options)
17
+ vmapping, indices, uvs = atlas[0]
18
+ return {
19
+ "vmapping": vmapping,
20
+ "indices": indices,
21
+ "uvs": uvs,
22
+ }
23
+
24
+
25
+ def rasterize_position_atlas(
26
+ mesh, atlas_vmapping, atlas_indices, atlas_uvs, texture_resolution, texture_padding
27
+ ):
28
+ ctx = moderngl.create_context(standalone=True)
29
+ basic_prog = ctx.program(
30
+ vertex_shader="""
31
+ #version 330
32
+ in vec2 in_uv;
33
+ in vec3 in_pos;
34
+ out vec3 v_pos;
35
+ void main() {
36
+ v_pos = in_pos;
37
+ gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
38
+ }
39
+ """,
40
+ fragment_shader="""
41
+ #version 330
42
+ in vec3 v_pos;
43
+ out vec4 o_col;
44
+ void main() {
45
+ o_col = vec4(v_pos, 1.0);
46
+ }
47
+ """,
48
+ )
49
+ gs_prog = ctx.program(
50
+ vertex_shader="""
51
+ #version 330
52
+ in vec2 in_uv;
53
+ in vec3 in_pos;
54
+ out vec3 vg_pos;
55
+ void main() {
56
+ vg_pos = in_pos;
57
+ gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
58
+ }
59
+ """,
60
+ geometry_shader="""
61
+ #version 330
62
+ uniform float u_resolution;
63
+ uniform float u_dilation;
64
+ layout (triangles) in;
65
+ layout (triangle_strip, max_vertices = 12) out;
66
+ in vec3 vg_pos[];
67
+ out vec3 vf_pos;
68
+ void lineSegment(int aidx, int bidx) {
69
+ vec2 a = gl_in[aidx].gl_Position.xy;
70
+ vec2 b = gl_in[bidx].gl_Position.xy;
71
+ vec3 aCol = vg_pos[aidx];
72
+ vec3 bCol = vg_pos[bidx];
73
+
74
+ vec2 dir = normalize((b - a) * u_resolution);
75
+ vec2 offset = vec2(-dir.y, dir.x) * u_dilation / u_resolution;
76
+
77
+ gl_Position = vec4(a + offset, 0.0, 1.0);
78
+ vf_pos = aCol;
79
+ EmitVertex();
80
+ gl_Position = vec4(a - offset, 0.0, 1.0);
81
+ vf_pos = aCol;
82
+ EmitVertex();
83
+ gl_Position = vec4(b + offset, 0.0, 1.0);
84
+ vf_pos = bCol;
85
+ EmitVertex();
86
+ gl_Position = vec4(b - offset, 0.0, 1.0);
87
+ vf_pos = bCol;
88
+ EmitVertex();
89
+ }
90
+ void main() {
91
+ lineSegment(0, 1);
92
+ lineSegment(1, 2);
93
+ lineSegment(2, 0);
94
+ EndPrimitive();
95
+ }
96
+ """,
97
+ fragment_shader="""
98
+ #version 330
99
+ in vec3 vf_pos;
100
+ out vec4 o_col;
101
+ void main() {
102
+ o_col = vec4(vf_pos, 1.0);
103
+ }
104
+ """,
105
+ )
106
+ uvs = atlas_uvs.flatten().astype("f4")
107
+ pos = mesh.vertices[atlas_vmapping].flatten().astype("f4")
108
+ indices = atlas_indices.flatten().astype("i4")
109
+ vbo_uvs = ctx.buffer(uvs)
110
+ vbo_pos = ctx.buffer(pos)
111
+ ibo = ctx.buffer(indices)
112
+ vao_content = [
113
+ vbo_uvs.bind("in_uv", layout="2f"),
114
+ vbo_pos.bind("in_pos", layout="3f"),
115
+ ]
116
+ basic_vao = ctx.vertex_array(basic_prog, vao_content, ibo)
117
+ gs_vao = ctx.vertex_array(gs_prog, vao_content, ibo)
118
+ fbo = ctx.framebuffer(
119
+ color_attachments=[
120
+ ctx.texture((texture_resolution, texture_resolution), 4, dtype="f4")
121
+ ]
122
+ )
123
+ fbo.use()
124
+ fbo.clear(0.0, 0.0, 0.0, 0.0)
125
+ gs_prog["u_resolution"].value = texture_resolution
126
+ gs_prog["u_dilation"].value = texture_padding
127
+ gs_vao.render()
128
+ basic_vao.render()
129
+
130
+ fbo_bytes = fbo.color_attachments[0].read()
131
+ fbo_np = np.frombuffer(fbo_bytes, dtype="f4").reshape(
132
+ texture_resolution, texture_resolution, 4
133
+ )
134
+ return fbo_np
135
+
136
+
137
+ def positions_to_colors(model, scene_code, positions_texture, texture_resolution):
138
+ positions = torch.tensor(positions_texture.reshape(-1, 4)[:, :-1])
139
+ with torch.no_grad():
140
+ queried_grid = model.renderer.query_triplane(
141
+ model.decoder,
142
+ positions,
143
+ scene_code,
144
+ )
145
+ rgb_f = queried_grid["color"].numpy().reshape(-1, 3)
146
+ rgba_f = np.insert(rgb_f, 3, positions_texture.reshape(-1, 4)[:, -1], axis=1)
147
+ rgba_f[rgba_f[:, -1] == 0.0] = [0, 0, 0, 0]
148
+ return rgba_f.reshape(texture_resolution, texture_resolution, 4)
149
+
150
+
151
+ def bake_texture(mesh, model, scene_code, texture_resolution):
152
+ texture_padding = round(max(2, texture_resolution / 256))
153
+ atlas = make_atlas(mesh, texture_resolution, texture_padding)
154
+ positions_texture = rasterize_position_atlas(
155
+ mesh,
156
+ atlas["vmapping"],
157
+ atlas["indices"],
158
+ atlas["uvs"],
159
+ texture_resolution,
160
+ texture_padding,
161
+ )
162
+ colors_texture = positions_to_colors(
163
+ model, scene_code, positions_texture, texture_resolution
164
+ )
165
+ return {
166
+ "vmapping": atlas["vmapping"],
167
+ "indices": atlas["indices"],
168
+ "uvs": atlas["uvs"],
169
+ "colors": colors_texture,
170
+ }