# PolyGenixAI6.0 / api.py
# Last commit 3a9b68a ("Upload 2 files", verified) by anvilinteractiv; file size 9.25 kB.
# NOTE: keep `spaces` first — on ZeroGPU Spaces it is expected to be imported
# before torch/CUDA is initialised.
import spaces

import json
import os
import random
import secrets
import shutil
import subprocess
import sys
import urllib.request
from typing import Optional

import numpy as np
import torch
import trimesh
import uvicorn
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
from pydantic import BaseModel, Field
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Install additional dependencies.
# spandrel is pinned and installed without its dependency tree. It is installed
# through the interpreter running this script (`python -m pip`), so it lands in
# the same environment that imports it, and no shell is involved.
# fastapi/uvicorn are not installed here: they are already imported at the top
# of this file, so a runtime install of them could never take effect.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "spandrel==0.4.1", "--no-deps"],
    check=True,
)
# Compute device, chosen once at import time.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Half precision for the TripoSG pipeline, independent of DEVICE.
DTYPE = torch.float16
print("DEVICE: ", DEVICE)

# Generation defaults.
DEFAULT_FACE_NUMBER = 100000
MAX_SEED = np.iinfo(np.int32).max

# Upstream source repositories.
TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
MV_ADAPTER_REPO_URL = "https://github.com/huanngzh/MV-Adapter.git"

# Local checkpoint directories (relative to the working directory).
RMBG_PRETRAINED_MODEL = "checkpoints/RMBG-1.4"
TRIPOSG_PRETRAINED_MODEL = "checkpoints/TripoSG"

# Scratch/output directory located next to this file; created eagerly.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)
def _clone_repo(url, dest, commit=None):
    """Clone git repository *url* into *dest* unless *dest* already exists.

    If *commit* is given, check it out after cloning. git runs without a shell
    and any failure raises (CalledProcessError / OSError). A partially prepared
    *dest* is removed first, so the next start retries instead of silently
    reusing an incomplete or wrongly-pinned checkout.
    """
    if os.path.exists(dest):
        return
    try:
        subprocess.run(["git", "clone", url, dest], check=True)
        if commit is not None:
            subprocess.run(["git", "-C", dest, "checkout", commit], check=True)
    except (subprocess.CalledProcessError, OSError):
        shutil.rmtree(dest, ignore_errors=True)
        raise


TRIPOSG_CODE_DIR = "./triposg"
_clone_repo(TRIPOSG_REPO_URL, TRIPOSG_CODE_DIR)
MV_ADAPTER_CODE_DIR = "./mv_adapter"
# MV-Adapter is pinned to a fixed commit.
_clone_repo(MV_ADAPTER_REPO_URL, MV_ADAPTER_CODE_DIR, commit="7d37a97e9bc223cdb8fd26a76bd8dd46504c7c3d")

import sys

# Make the cloned packages and their helper scripts importable.
sys.path.append(TRIPOSG_CODE_DIR)
sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts"))
sys.path.append(MV_ADAPTER_CODE_DIR)
sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
# ---- TripoSG: RMBG-1.4 background remover and the image-to-mesh pipeline ----
from image_process import prepare_image
from briarmbg import BriaRMBG

snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
# nn.Module.to() and .eval() both return the module itself, so they chain.
rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE).eval()

from triposg.pipelines.pipeline_triposg import TripoSGPipeline

snapshot_download(repo_id="VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, DTYPE)
# ---- MV-Adapter: multiview SDXL pipeline and BiRefNet background removal ----
NUM_VIEWS = 6  # camera views rendered and generated per model
from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render

mv_adapter_pipe = prepare_pipeline(
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    unet_model=None,
    lora_model=None,
    adapter_path="huanngzh/mv-adapter",
    scheduler=None,
    num_views=NUM_VIEWS,
    device=DEVICE,
    dtype=torch.float16,
)

# Segmentation model handed to remove_bg (remote code from the model repo).
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
).to(DEVICE)

# BiRefNet input preprocessing: resize to 1024x1024, convert to a tensor, and
# normalise with the ImageNet per-channel mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def remove_bg_fn(x):
    """Return *x* with its background removed, via remove_bg + BiRefNet.

    A named function instead of a lambda assignment (PEP 8 E731), so the
    callable has a real name in tracebacks.
    """
    return remove_bg(x, birefnet, transform_image, DEVICE)
# Auxiliary checkpoints used by the texturing step: the RealESRGAN x2 upscaler
# and the big-lama inpainting model. Each download is skipped if the file exists.
if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
    hf_hub_download("dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints")
if not os.path.exists("checkpoints/big-lama.pt"):
    # Download to a temporary name and rename only on success. An interrupted
    # download then never leaves a truncated file that the exists() check above
    # would accept on the next start. Uses the stdlib instead of shelling out
    # to wget, which may not be installed.
    os.makedirs("checkpoints", exist_ok=True)
    _lama_part = "checkpoints/big-lama.pt.part"
    try:
        urllib.request.urlretrieve(
            "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
            _lama_part,
        )
        os.replace(_lama_part, "checkpoints/big-lama.pt")
    finally:
        if os.path.exists(_lama_part):
            os.remove(_lama_part)
# FastAPI application. Everything under TMP_DIR is served by StaticFiles at
# /files; this mount has no API-key dependency attached.
app = FastAPI()
app.mount(path="/files", app=StaticFiles(directory=TMP_DIR), name="files")
# API key authentication.
# Clients send the key in the X-API-Key header. APIKeyHeader (auto_error
# defaults to True) rejects requests that omit the header.
api_key_header = APIKeyHeader(name="X-API-Key")
_PLACEHOLDER_API_KEY = "your-secret-api-key"
VALID_API_KEY = os.getenv("POLYGENIX_API_KEY", _PLACEHOLDER_API_KEY)
if VALID_API_KEY == _PLACEHOLDER_API_KEY:
    # The placeholder is kept for backward compatibility, but it is public in
    # this source file, so it provides no real protection.
    print("WARNING: POLYGENIX_API_KEY is unset or equals the placeholder; the API key is insecure.")


async def verify_api_key(api_key: str = Depends(api_key_header)):
    """FastAPI dependency that admits the request only if *api_key* is valid.

    Returns the key on success and raises HTTPException(401) otherwise. The
    comparison is constant-time (secrets.compare_digest), so response timing
    does not reveal how much of a guessed key was correct. Both sides are
    compared as UTF-8 bytes because compare_digest rejects non-ASCII str.
    """
    if not secrets.compare_digest(api_key.encode("utf-8"), VALID_API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key
# API request model
class GenerateRequest(BaseModel):
    """Generation options for /api/generate; every field has a default.

    Field names match the tunables that run_full sets locally. Values the
    pipeline cannot use (non-positive step or face counts) are rejected when
    the model is validated, instead of failing deep inside inference.
    NOTE(review): the values only take effect if run_full reads them from `req`.
    """

    # run_full seeds TripoSG with this; -1 skips seeding the MV-Adapter pass.
    seed: int = 0
    # Inference steps passed to the TripoSG pipeline.
    num_inference_steps: int = Field(default=50, ge=1)
    # Guidance scale passed to the TripoSG pipeline.
    guidance_scale: float = 7.5
    # Whether run_full simplifies the mesh before texturing.
    simplify: bool = True
    # Face count handed to simplify_mesh when `simplify` is True.
    target_face_num: int = Field(default=DEFAULT_FACE_NUMBER, ge=1)
# Test endpoint
@app.get("/api/test")
async def test_endpoint():
    """Liveness probe: confirm the FastAPI app is serving requests."""
    payload = dict(message="FastAPI is running")
    return payload
def get_random_hex():
    """Return 16 lowercase hex characters encoding 8 bytes from os.urandom.

    Used throughout the file to give generated files and directories
    collision-resistant names.
    """
    return os.urandom(8).hex()
# Azimuths (degrees) of the six orthogonal views. The same values are used to
# render the control images and to drive TexturePipeline, so they must agree.
_MV_AZIMUTH_DEG = tuple(x - 90 for x in [0, 90, 180, 270, 180, 180])


def _generate_mesh(image, save_dir, seed, num_inference_steps, guidance_scale, simplify, target_face_num):
    """Segment the image at path *image*, run TripoSG, and export the mesh.

    The mesh is optionally simplified, then exported as a .glb with a random
    name into *save_dir*. Returns (image_seg, mesh_path).
    """
    image_seg = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
    outputs = triposg_pipe(
        image=image_seg,
        generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]
    print("mesh extraction done")
    # trimesh.Trimesh(vertices, faces)
    mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
    if simplify:
        print("start simplify")
        from utils import simplify_mesh  # project-local helper

        mesh = simplify_mesh(mesh, target_face_num)
    mesh_path = os.path.join(save_dir, f"polygenixai_{get_random_hex()}.glb")
    mesh.export(mesh_path)
    print("save to ", mesh_path)
    torch.cuda.empty_cache()
    return image_seg, mesh_path


def _render_control_images(mesh_path, height, width):
    """Render position and normal maps of the mesh from NUM_VIEWS orthogonal cameras.

    Returns both renders, mapped into [0, 1], concatenated on the last axis,
    permuted with (0, 3, 1, 2), and placed on DEVICE.
    """
    cameras = get_orthogonal_camera(
        elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
        distance=[1.8] * NUM_VIEWS,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=list(_MV_AZIMUTH_DEG),
        device=DEVICE,
    )
    ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
    mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
    render_out = render(
        ctx,
        mesh,
        cameras,
        height=height,
        width=width,
        render_attr=False,
        normal_background=0.0,
    )
    positions = (render_out.pos + 0.5).clamp(0, 1)
    normals = (render_out.normal / 2 + 0.5).clamp(0, 1)
    return torch.cat([positions, normals], dim=-1).permute(0, 3, 1, 2).to(DEVICE)


def _generate_multiview(image, control_images, seed, height, width, save_dir):
    """Run MV-Adapter with the input photo as reference.

    Saves the generated views as a one-row grid PNG in *save_dir* and returns
    its path.
    """
    # Copy the pixels and close the file immediately instead of leaking the
    # handle that a bare Image.open() keeps.
    with Image.open(image) as src:
        reference = src.copy()
    reference = remove_bg_fn(reference)
    reference = preprocess_image(reference, height, width)
    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
    images = mv_adapter_pipe(
        "high quality",
        height=height,
        width=width,
        num_inference_steps=15,
        guidance_scale=3.0,
        num_images_per_prompt=NUM_VIEWS,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=reference,
        reference_conditioning_scale=1.0,
        negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
        cross_attention_kwargs={"scale": 1.0},
        **pipe_kwargs,
    ).images
    torch.cuda.empty_cache()
    mv_image_path = os.path.join(save_dir, f"polygenixai_mv_{get_random_hex()}.png")
    make_image_grid(images, rows=1).save(mv_image_path)
    return mv_image_path


def _bake_texture(mesh_path, mv_image_path, save_dir):
    """Run TexturePipeline on the mesh with the multiview grid as RGB input.

    Returns whatever TexturePipeline returns, which is used as the textured
    .glb path.
    """
    from texture import TexturePipeline, ModProcessConfig  # project-local module

    texture_pipe = TexturePipeline(
        upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
        inpaint_ckpt_path="checkpoints/big-lama.pt",
        device=DEVICE,
    )
    return texture_pipe(
        mesh_path=mesh_path,
        save_dir=save_dir,
        save_name=f"polygenixai_texture_mesh_{get_random_hex()}.glb",
        uv_unwarp=True,
        uv_size=4096,
        rgb_path=mv_image_path,
        rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
        camera_azimuth_deg=list(_MV_AZIMUTH_DEG),
    )


@spaces.GPU(duration=180)
def run_full(image: str, req=None):
    """Run the image-to-textured-3D pipeline on the image file at path *image*.

    Args:
        image: Path to the input image file.
        req: Optional object (e.g. a GenerateRequest) whose ``seed``,
            ``num_inference_steps``, ``guidance_scale``, ``simplify`` and
            ``target_face_num`` attributes override the defaults. When ``req``
            is None or lacks an attribute, the defaults are used: 0, 50, 7.5,
            True, DEFAULT_FACE_NUMBER. A seed of -1 leaves the MV-Adapter pass
            unseeded.

    Returns:
        (image_seg, mesh_path, textured_glb_path). The mesh and multiview grid
        are written to TMP_DIR/examples, which is also the save_dir given to
        TexturePipeline.
    """
    seed = getattr(req, "seed", 0)
    num_inference_steps = getattr(req, "num_inference_steps", 50)
    guidance_scale = getattr(req, "guidance_scale", 7.5)
    simplify = getattr(req, "simplify", True)
    target_face_num = getattr(req, "target_face_num", DEFAULT_FACE_NUMBER)

    save_dir = os.path.join(TMP_DIR, "examples")
    os.makedirs(save_dir, exist_ok=True)
    height, width = 768, 768

    image_seg, mesh_path = _generate_mesh(
        image, save_dir, seed, num_inference_steps, guidance_scale, simplify, target_face_num
    )
    control_images = _render_control_images(mesh_path, height, width)
    mv_image_path = _generate_multiview(image, control_images, seed, height, width, save_dir)
    textured_glb_path = _bake_texture(mesh_path, mv_image_path, save_dir)
    return image_seg, mesh_path, textured_glb_path
# FastAPI endpoint for generating 3D models
def _parse_generate_request(request: Optional[str] = Form(None)) -> GenerateRequest:
    """Build a GenerateRequest from the optional multipart form field ``request``.

    The endpoint receives the image as multipart/form-data, so options cannot
    also arrive as a JSON body. Clients may instead send them as a JSON-object
    string in the ``request`` form field; if it is omitted or blank, all
    defaults apply. Raises HTTPException(422) when the field is not a JSON
    object or fails GenerateRequest validation.
    """
    if request is None or not request.strip():
        return GenerateRequest()
    try:
        payload = json.loads(request)
        if not isinstance(payload, dict):
            raise ValueError("expected a JSON object")
        return GenerateRequest(**payload)
    except (ValueError, TypeError) as exc:  # JSONDecodeError/ValidationError are ValueErrors
        raise HTTPException(status_code=422, detail=f"Invalid 'request' field: {exc}") from exc


@app.post("/api/generate")
async def generate_3d_model(request: GenerateRequest = Depends(_parse_generate_request), image: UploadFile = File(...), api_key: str = Depends(verify_api_key)):
    """Generate a textured 3D model from an uploaded image.

    Expects multipart/form-data with the image in ``image`` and, optionally,
    options as a JSON string in ``request``. The X-API-Key header is checked
    by verify_api_key. Returns ``{"file_url": "/files/<path>"}`` for the
    textured GLB under the /files static mount. Pipeline failures become
    HTTP 500 with the exception text.
    """
    # This per-request directory only holds the uploaded input. It is computed
    # before `try` so `finally` can always refer to it.
    session_hash = get_random_hex()
    save_dir = os.path.join(TMP_DIR, session_hash)
    try:
        os.makedirs(save_dir, exist_ok=True)
        image_path = os.path.join(save_dir, f"input_{get_random_hex()}.png")
        with open(image_path, "wb") as f:
            f.write(await image.read())
        # NOTE(review): run_full is synchronous and long-running; calling it
        # here blocks the event loop (and all other requests) until it returns.
        _, _, textured_glb_path = run_full(image_path, req=request)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        shutil.rmtree(save_dir, ignore_errors=True)

    # run_full writes its outputs under TMP_DIR/examples, not into save_dir.
    # TMP_DIR is what /files serves, so the URL is the path relative to TMP_DIR.
    rel_path = os.path.relpath(textured_glb_path, TMP_DIR)
    if rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep):
        raise HTTPException(status_code=500, detail="Generated model is outside the served directory")
    return {"file_url": "/files/" + rel_path.replace(os.sep, "/")}
if __name__ == "__main__":
    # Serve the API on all network interfaces, port 8000.
    bind = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **bind)