anvilinteractiv commited on
Commit
6aa6a22
·
verified ·
1 Parent(s): 4b73ccb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -17
app.py CHANGED
@@ -11,9 +11,14 @@ from torchvision import transforms
11
  from huggingface_hub import hf_hub_download, snapshot_download
12
  import subprocess
13
  import shutil
 
 
 
 
14
 
15
  # Install additional dependencies
16
  subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
 
17
 
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
  DTYPE = torch.float16
@@ -45,6 +50,34 @@ sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts"))
45
  sys.path.append(MV_ADAPTER_CODE_DIR)
46
  sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  # triposg
49
  from image_process import prepare_image
50
  from briarmbg import BriaRMBG
@@ -88,27 +121,13 @@ if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
88
  if not os.path.exists("checkpoints/big-lama.pt"):
89
  subprocess.run("wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", shell=True, check=True)
90
 
91
- def start_session(req: gr.Request):
92
- save_dir = os.path.join(TMP_DIR, str(req.session_hash))
93
- os.makedirs(save_dir, exist_ok=True)
94
- print("start session, mkdir", save_dir)
95
-
96
- def end_session(req: gr.Request):
97
- save_dir = os.path.join(TMP_DIR, str(req.session_hash))
98
- shutil.rmtree(save_dir)
99
-
100
  def get_random_hex():
101
  random_bytes = os.urandom(8)
102
  random_hex = random_bytes.hex()
103
  return random_hex
104
 
105
- def get_random_seed(randomize_seed, seed):
106
- if randomize_seed:
107
- seed = random.randint(0, MAX_SEED)
108
- return seed
109
-
110
  @spaces.GPU(duration=180)
111
- def run_full(image: str, req: gr.Request):
112
  seed = 0
113
  num_inference_steps = 50
114
  guidance_scale = 7.5
@@ -223,6 +242,45 @@ def run_full(image: str, req: gr.Request):
223
 
224
  return image_seg, mesh_path, textured_glb_path
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  @spaces.GPU()
227
  @torch.no_grad()
228
  def run_segmentation(image: str):
@@ -327,7 +385,6 @@ def run_texture(image: Image, mesh_path: str, seed: int, req: gr.Request):
327
 
328
  torch.cuda.empty_cache()
329
 
330
- save_dir = os.path.join(TMP_DIR, str(req.session_hash))
331
  mv_image_path = os.path.join(save_dir, f"polygenixai_mv_{get_random_hex()}.png")
332
  make_image_grid(images, rows=1).save(mv_image_path)
333
 
@@ -534,5 +591,9 @@ with gr.Blocks(title="PolyGenixAI", css="body { background-color: #1A1A1A; } .gr
534
  demo.load(start_session)
535
  demo.unload(end_session)
536
 
 
 
 
537
  if __name__ == "__main__":
538
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
11
  from huggingface_hub import hf_hub_download, snapshot_download
12
  import subprocess
13
  import shutil
14
+ from fastapi import FastAPI, HTTPException, Depends, File, UploadFile
15
+ from fastapi.security import APIKeyHeader
16
+ from fastapi.staticfiles import StaticFiles
17
+ from pydantic import BaseModel
18
 
19
  # Install additional dependencies
20
  subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
21
+ subprocess.run("pip install fastapi uvicorn", shell=True, check=True)
22
 
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
  DTYPE = torch.float16
 
50
  sys.path.append(MV_ADAPTER_CODE_DIR)
51
  sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
52
 
53
+ # Initialize FastAPI app
54
+ app = FastAPI()
55
+
56
+ # Mount static files for serving generated models
57
+ app.mount("/files", StaticFiles(directory=TMP_DIR), name="files")
58
+
59
+ # API key authentication
60
+ api_key_header = APIKeyHeader(name="X-API-Key")
61
+ VALID_API_KEY = os.getenv("POLYGENIX_API_KEY", "your-secret-api-key")
62
+
63
+ async def verify_api_key(api_key: str = Depends(api_key_header)):
64
+ if api_key != VALID_API_KEY:
65
+ raise HTTPException(status_code=401, detail="Invalid API key")
66
+ return api_key
67
+
68
+ # API request model
69
+ class GenerateRequest(BaseModel):
70
+ seed: int = 0
71
+ num_inference_steps: int = 50
72
+ guidance_scale: float = 7.5
73
+ simplify: bool = True
74
+ target_face_num: int = DEFAULT_FACE_NUMBER
75
+
76
+ # Test endpoint
77
+ @app.get("/api/test")
78
+ async def test_endpoint():
79
+ return {"message": "FastAPI is running"}
80
+
81
  # triposg
82
  from image_process import prepare_image
83
  from briarmbg import BriaRMBG
 
121
  if not os.path.exists("checkpoints/big-lama.pt"):
122
  subprocess.run("wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", shell=True, check=True)
123
 
 
 
 
 
 
 
 
 
 
124
  def get_random_hex():
125
  random_bytes = os.urandom(8)
126
  random_hex = random_bytes.hex()
127
  return random_hex
128
 
 
 
 
 
 
129
  @spaces.GPU(duration=180)
130
+ def run_full(image: str, req=None):
131
  seed = 0
132
  num_inference_steps = 50
133
  guidance_scale = 7.5
 
242
 
243
  return image_seg, mesh_path, textured_glb_path
244
 
245
+ # FastAPI endpoint for generating 3D models
246
+ @app.post("/api/generate")
247
+ async def generate_3d_model(request: GenerateRequest, image: UploadFile = File(...), api_key: str = Depends(verify_api_key)):
248
+ try:
249
+ # Save uploaded image to temporary directory
250
+ session_hash = get_random_hex()
251
+ save_dir = os.path.join(TMP_DIR, session_hash)
252
+ os.makedirs(save_dir, exist_ok=True)
253
+ image_path = os.path.join(save_dir, f"input_{get_random_hex()}.png")
254
+ with open(image_path, "wb") as f:
255
+ f.write(await image.read())
256
+
257
+ # Run the full pipeline
258
+ image_seg, mesh_path, textured_glb_path = run_full(image_path, req=None)
259
+
260
+ # Return the file URL for the textured GLB
261
+ file_url = f"/files/{session_hash}/{os.path.basename(textured_glb_path)}"
262
+ return {"file_url": file_url}
263
+ except Exception as e:
264
+ raise HTTPException(status_code=500, detail=str(e))
265
+ finally:
266
+ # Clean up temporary directory
267
+ if os.path.exists(save_dir):
268
+ shutil.rmtree(save_dir)
269
+
270
+ def start_session(req: gr.Request):
271
+ save_dir = os.path.join(TMP_DIR, str(req.session_hash))
272
+ os.makedirs(save_dir, exist_ok=True)
273
+ print("start session, mkdir", save_dir)
274
+
275
+ def end_session(req: gr.Request):
276
+ save_dir = os.path.join(TMP_DIR, str(req.session_hash))
277
+ shutil.rmtree(save_dir)
278
+
279
+ def get_random_seed(randomize_seed, seed):
280
+ if randomize_seed:
281
+ seed = random.randint(0, MAX_SEED)
282
+ return seed
283
+
284
  @spaces.GPU()
285
  @torch.no_grad()
286
  def run_segmentation(image: str):
 
385
 
386
  torch.cuda.empty_cache()
387
 
 
388
  mv_image_path = os.path.join(save_dir, f"polygenixai_mv_{get_random_hex()}.png")
389
  make_image_grid(images, rows=1).save(mv_image_path)
390
 
 
591
  demo.load(start_session)
592
  demo.unload(end_session)
593
 
594
+ # Mount Gradio to FastAPI
595
+ app = gr.mount_gradio_app(app, demo, path="/")
596
+
597
  if __name__ == "__main__":
598
+ import uvicorn
599
+ uvicorn.run(app, host="0.0.0.0", port=7860)