anvilinteractiv committed on
Commit
3a9b68a
·
verified ·
1 Parent(s): 6f49341

Upload 2 files

Browse files
Files changed (2) hide show
  1. api.py +271 -0
  2. main (1).py +14 -0
api.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ import trimesh
7
+ import random
8
+ from transformers import AutoModelForImageSegmentation
9
+ from torchvision import transforms
10
+ from huggingface_hub import hf_hub_download, snapshot_download
11
+ import subprocess
12
+ import shutil
13
+ from fastapi import FastAPI, HTTPException, Depends, File, UploadFile
14
+ from fastapi.security import APIKeyHeader
15
+ from fastapi.staticfiles import StaticFiles
16
+ from pydantic import BaseModel
17
+ import uvicorn
18
+
19
# Install additional dependencies at start-up.
# The commands are passed as argument lists (no shell) and check=True aborts
# start-up if either installation fails.
subprocess.run(
    ["pip", "install", "spandrel==0.4.1", "--no-deps"],
    check=True,
)
subprocess.run(
    ["pip", "install", "fastapi", "uvicorn"],
    check=True,
)
22
+
23
# Compute device: CUDA when torch reports it as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is used for the TripoSG pipeline weights.
DTYPE = torch.float16

print("DEVICE: ", DEVICE)

# Generation defaults.
DEFAULT_FACE_NUMBER = 100000
MAX_SEED = np.iinfo(np.int32).max

# Upstream source repositories that are cloned next to this file on demand.
TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
MV_ADAPTER_REPO_URL = "https://github.com/huanngzh/MV-Adapter.git"

# Local directories the pretrained weights are downloaded into.
RMBG_PRETRAINED_MODEL = "checkpoints/RMBG-1.4"
TRIPOSG_PRETRAINED_MODEL = "checkpoints/TripoSG"

# Scratch directory (created beside this file) for uploads and generated assets.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)
38
+
39
# Fetch the TripoSG and MV-Adapter sources on first start and make them
# importable. git is invoked with check=True so that a failed clone or checkout
# stops start-up here instead of surfacing later as an ImportError.
TRIPOSG_CODE_DIR = "./triposg"
if not os.path.exists(TRIPOSG_CODE_DIR):
    subprocess.run(
        ["git", "clone", TRIPOSG_REPO_URL, TRIPOSG_CODE_DIR],
        check=True,
    )

MV_ADAPTER_CODE_DIR = "./mv_adapter"
if not os.path.exists(MV_ADAPTER_CODE_DIR):
    subprocess.run(
        ["git", "clone", MV_ADAPTER_REPO_URL, MV_ADAPTER_CODE_DIR],
        check=True,
    )
    # Pin MV-Adapter to a fixed commit; cwd= replaces the former `cd ... &&`.
    subprocess.run(
        ["git", "checkout", "7d37a97e9bc223cdb8fd26a76bd8dd46504c7c3d"],
        cwd=MV_ADAPTER_CODE_DIR,
        check=True,
    )

import sys

# Both repositories keep importable modules at their root and under scripts/.
sys.path.append(TRIPOSG_CODE_DIR)
sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts"))
sys.path.append(MV_ADAPTER_CODE_DIR)
sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
52
+
53
# triposg
# These imports resolve through the sys.path entries added for the cloned
# repositories, so they must stay below that setup.
from image_process import prepare_image
from briarmbg import BriaRMBG

# RMBG-1.4 network; it is handed to prepare_image() as `rmbg_net`.
snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL)
rmbg_net = rmbg_net.to(DEVICE)
rmbg_net.eval()

from triposg.pipelines.pipeline_triposg import TripoSGPipeline

# TripoSG pipeline, moved to DEVICE and cast to DTYPE.
snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL)
triposg_pipe = triposg_pipe.to(DEVICE, DTYPE)
62
+
63
# mv adapter
# Number of views generated per request; run_full() supplies one camera per view.
NUM_VIEWS = 6

from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render

# SDXL-based multi-view pipeline with the MV-Adapter weights.
mv_adapter_pipe = prepare_pipeline(
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    unet_model=None,
    lora_model=None,
    adapter_path="huanngzh/mv-adapter",
    scheduler=None,
    num_views=NUM_VIEWS,
    device=DEVICE,
    dtype=torch.float16,
)

# BiRefNet segmentation model used by remove_bg().
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet = birefnet.to(DEVICE)

# Preprocessing applied to an image before it is passed to BiRefNet.
_resize_step = transforms.Resize((1024, 1024))
_normalize_step = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
transform_image = transforms.Compose(
    [
        _resize_step,
        transforms.ToTensor(),
        _normalize_step,
    ]
)


def remove_bg_fn(x):
    # Bind the module-level model, transform and device to remove_bg().
    return remove_bg(x, birefnet, transform_image, DEVICE)
90
+
91
# Download the upscaler and inpainting checkpoints unless they already exist.
if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
    hf_hub_download("dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints")
if not os.path.exists("checkpoints/big-lama.pt"):
    # Argument list instead of a shell string; check=True aborts start-up if
    # the download fails.
    subprocess.run(
        [
            "wget",
            "-P",
            "checkpoints/",
            "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
        ],
        check=True,
    )
95
+
96
# Initialize FastAPI app
app = FastAPI()

# Mount static files for serving generated models.
# NOTE(review): everything under TMP_DIR is served here without the API-key
# dependency -- confirm that this is intended.
app.mount("/files", StaticFiles(directory=TMP_DIR), name="files")

# API key authentication: clients send the key in the X-API-Key header.
api_key_header = APIKeyHeader(name="X-API-Key")
VALID_API_KEY = os.getenv("POLYGENIX_API_KEY", "your-secret-api-key")
if VALID_API_KEY == "your-secret-api-key":
    # The fallback value is a placeholder that is visible in the source, so
    # make its use obvious in the start-up log.
    print("WARNING: the placeholder API key is in use; set POLYGENIX_API_KEY to a private value.")
105
+
106
async def verify_api_key(api_key: str = Depends(api_key_header)):
    """Validate the API key taken from the ``X-API-Key`` header.

    Args:
        api_key: Key extracted from the request by ``api_key_header``.

    Returns:
        The key that was supplied, when it matches ``VALID_API_KEY``.

    Raises:
        HTTPException: With status 401 when the key does not match.
    """
    import hmac

    # Compare in constant time so the check does not leak, through timing, how
    # much of a guessed key is correct. Bytes are compared because
    # compare_digest() rejects non-ASCII str arguments.
    supplied = api_key.encode("utf-8")
    expected = VALID_API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key
110
+
111
# API request model: the generation options accepted by /api/generate.
# (Documented with comments rather than a docstring so the generated schema
# stays exactly as before.)
class GenerateRequest(BaseModel):
    # Seed for the random generator.
    seed: int = 0
    # Number of inference steps.
    num_inference_steps: int = 50
    # Guidance scale.
    guidance_scale: float = 7.5
    # Whether the generated mesh is simplified.
    simplify: bool = True
    # Face budget handed to the mesh simplifier.
    target_face_num: int = DEFAULT_FACE_NUMBER
118
+
119
# Test endpoint: returns a fixed payload so callers can check the API is up.
@app.get("/api/test")
async def test_endpoint():
    status = {"message": "FastAPI is running"}
    return status
123
+
124
def get_random_hex():
    """Return 16 hexadecimal characters built from 8 bytes of ``os.urandom``."""
    return os.urandom(8).hex()
128
+
129
@spaces.GPU(duration=180)
def run_full(image: str, req=None):
    """Generate a mesh and a textured GLB from the image stored at ``image``.

    The function segments the input with ``prepare_image``, runs
    ``triposg_pipe`` to obtain vertices and faces, optionally simplifies the
    mesh, exports it as GLB, renders position/normal maps for ``NUM_VIEWS``
    orthographic cameras, runs ``mv_adapter_pipe`` conditioned on those maps
    and on the background-removed input, saves the views as one image grid and
    finally runs ``TexturePipeline`` on the mesh and the grid. All files are
    written to ``TMP_DIR/examples``.

    Args:
        image: Filesystem path of the input image.
        req: Optional object carrying generation options. The attributes
            ``seed``, ``num_inference_steps``, ``guidance_scale``,
            ``simplify`` and ``target_face_num`` are used when present and not
            ``None`` (for example a ``GenerateRequest``). With ``req=None`` the
            defaults 0, 50, 7.5, ``True`` and ``DEFAULT_FACE_NUMBER`` apply.

    Returns:
        The tuple ``(image_seg, mesh_path, textured_glb_path)``.
    """

    def _option(name, default):
        # Read one option from `req`; fall back when it is absent or None.
        value = getattr(req, name, None)
        return default if value is None else value

    seed = _option("seed", 0)
    num_inference_steps = _option("num_inference_steps", 50)
    guidance_scale = _option("guidance_scale", 7.5)
    simplify = _option("simplify", True)
    target_face_num = _option("target_face_num", DEFAULT_FACE_NUMBER)

    # Azimuths of the six views; shared by the control renders and the texture
    # projection so that the two always agree.
    view_azimuths = [x - 90 for x in [0, 90, 180, 270, 180, 180]]

    # --- Shape generation -------------------------------------------------
    image_seg = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)

    outputs = triposg_pipe(
        image=image_seg,
        generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]
    print("mesh extraction done")
    mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))

    if simplify:
        print("start simplify")
        from utils import simplify_mesh

        mesh = simplify_mesh(mesh, target_face_num)

    save_dir = os.path.join(TMP_DIR, "examples")
    os.makedirs(save_dir, exist_ok=True)
    mesh_path = os.path.join(save_dir, f"polygenixai_{get_random_hex()}.glb")
    mesh.export(mesh_path)
    print("save to ", mesh_path)

    torch.cuda.empty_cache()

    # --- Control images for the multi-view pipeline -------------------------
    height, width = 768, 768
    cameras = get_orthogonal_camera(
        elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
        distance=[1.8] * NUM_VIEWS,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=list(view_azimuths),
        device=DEVICE,
    )
    ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")

    # Reload the exported GLB through the renderer's own loader.
    render_mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
    render_out = render(
        ctx,
        render_mesh,
        cameras,
        height=height,
        width=width,
        render_attr=False,
        normal_background=0.0,
    )
    # Position and normal maps are shifted/scaled, clamped to [0, 1],
    # concatenated along the last axis and permuted with (0, 3, 1, 2).
    control_images = (
        torch.cat(
            [
                (render_out.pos + 0.5).clamp(0, 1),
                (render_out.normal / 2 + 0.5).clamp(0, 1),
            ],
            dim=-1,
        )
        .permute(0, 3, 1, 2)
        .to(DEVICE)
    )

    # --- Multi-view image generation ---------------------------------------
    reference_image = Image.open(image)
    reference_image = remove_bg_fn(reference_image)
    reference_image = preprocess_image(reference_image, height, width)

    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)

    images = mv_adapter_pipe(
        "high quality",
        height=height,
        width=width,
        num_inference_steps=15,
        guidance_scale=3.0,
        num_images_per_prompt=NUM_VIEWS,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=reference_image,
        reference_conditioning_scale=1.0,
        negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
        cross_attention_kwargs={"scale": 1.0},
        **pipe_kwargs,
    ).images

    torch.cuda.empty_cache()

    mv_image_path = os.path.join(save_dir, f"polygenixai_mv_{get_random_hex()}.png")
    make_image_grid(images, rows=1).save(mv_image_path)

    # --- Texturing ----------------------------------------------------------
    from texture import TexturePipeline, ModProcessConfig

    texture_pipe = TexturePipeline(
        upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
        inpaint_ckpt_path="checkpoints/big-lama.pt",
        device=DEVICE,
    )

    textured_glb_path = texture_pipe(
        mesh_path=mesh_path,
        save_dir=save_dir,
        save_name=f"polygenixai_texture_mesh_{get_random_hex()}.glb",
        uv_unwarp=True,
        uv_size=4096,
        rgb_path=mv_image_path,
        rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
        camera_azimuth_deg=list(view_azimuths),
    )

    return image_seg, mesh_path, textured_glb_path
244
+
245
# FastAPI endpoint for generating 3D models.
# The uploaded image is stored in a per-request directory, run through
# run_full(), and the URL of the textured GLB (served by the /files mount) is
# returned. Generation options are bound with Depends(), i.e. read from query
# parameters, because a JSON body cannot be combined with a multipart upload.
# NOTE(review): run_full() is a blocking call inside an async handler, so the
# event loop is busy for the whole generation.
@app.post("/api/generate")
async def generate_3d_model(
    request: GenerateRequest = Depends(),
    image: UploadFile = File(...),
    api_key: str = Depends(verify_api_key),
):
    # Bind the per-request directory before `try` so the cleanup in `finally`
    # can never hit an unbound name.
    session_hash = get_random_hex()
    save_dir = os.path.join(TMP_DIR, session_hash)
    try:
        # Save uploaded image to temporary directory
        os.makedirs(save_dir, exist_ok=True)
        image_path = os.path.join(save_dir, f"input_{get_random_hex()}.png")
        with open(image_path, "wb") as f:
            f.write(await image.read())

        # Run the full pipeline with the options supplied by the client
        image_seg, mesh_path, textured_glb_path = run_full(image_path, req=request)

        # Build the URL from the file's real location below TMP_DIR. The
        # pipeline does not write into the session directory, which is removed
        # below anyway, so a URL based on session_hash could never resolve.
        relative_path = os.path.relpath(textured_glb_path, TMP_DIR).replace(os.sep, "/")
        file_url = f"/files/{relative_path}"
        return {"file_url": file_url}
    except HTTPException:
        # Pass deliberate HTTP errors through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up the uploaded input; generated assets live elsewhere.
        if os.path.exists(save_dir):
            shutil.rmtree(save_dir)
269
+
270
if __name__ == "__main__":
    # Executed as a script: serve the API on all interfaces, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
main (1).py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import uvicorn
3
+ from app import demo
4
+ from api import app
5
+
6
async def run_servers():
    """Start the demo UI on port 7860 and the FastAPI app on port 8000.

    The coroutine returns when the uvicorn server stops serving.
    """
    config = uvicorn.Config(app=app, host="0.0.0.0", port=8000)
    server = uvicorn.Server(config)
    # Assuming `demo` is a Gradio app (TODO confirm): launch() blocks the
    # calling thread by default, which stalled this event loop before the
    # FastAPI server ever got to run. prevent_thread_lock=True makes launch()
    # return once the UI is being served in the background.
    demo.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)
    await server.serve()
12
+
13
if __name__ == "__main__":
    # Executed as a script: run both servers until the API server stops.
    entry_point = run_servers()
    asyncio.run(entry_point)