etTripoSg-api-gradio / scripts /inference_triposg_scribble.py
staswrs
octree_depth all 9
b3806ac
raw
history blame
2.81 kB
import argparse
import os
import sys
from glob import glob
from typing import Any, Union
import numpy as np
import torch
import trimesh
from huggingface_hub import snapshot_download
from PIL import Image
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from triposg.pipelines.pipeline_triposg_scribble import TripoSGScribblePipeline
@torch.no_grad()
def run_triposg_scribble(
pipe: Any,
image_input: Union[str, Image.Image],
prompt: str,
seed: int,
num_inference_steps: int = 16,
scribble_confidence: float = 0.4,
prompt_confidence: float = 1.0
) -> trimesh.Scene:
img_pil = Image.open(image_input).convert("RGB")
outputs = pipe(
image=img_pil,
prompt=prompt,
generator=torch.Generator(device=pipe.device).manual_seed(seed),
num_inference_steps=num_inference_steps,
guidance_scale=0, # this is a CFG-distilled model
attention_kwargs={"cross_attention_scale": prompt_confidence, "cross_attention_2_scale": scribble_confidence},
use_flash_decoder=False, # there're some boundary problems when using flash decoder with this model
dense_octree_depth=9, hierarchical_octree_depth=9 # 256 resolution for faster inference
).samples[0]
mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
return mesh
if __name__ == "__main__":
device = "cuda"
dtype = torch.float16
parser = argparse.ArgumentParser()
parser.add_argument("--image-input", type=str, required=True)
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--output-path", type=str, default="./output.glb")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--num-inference-steps", type=int, default=16)
# feel free to tune the scribble confidence, 0.3-0.5 often gives good results for hand-drawn sketches
parser.add_argument("--scribble-conf", type=float, default=0.4)
parser.add_argument("--prompt-conf", type=float, default=1.0)
args = parser.parse_args()
# download pretrained weights
triposg_scribble_weights_dir = "pretrained_weights/TripoSG-scribble"
snapshot_download(repo_id="VAST-AI/TripoSG-scribble", local_dir=triposg_scribble_weights_dir)
# init tripoSG pipeline
pipe: TripoSGScribblePipeline = TripoSGScribblePipeline.from_pretrained(triposg_scribble_weights_dir).to(device, dtype)
# run inference
run_triposg_scribble(
pipe,
image_input=args.image_input,
prompt=args.prompt,
seed=args.seed,
num_inference_steps=args.num_inference_steps,
scribble_confidence=args.scribble_conf,
prompt_confidence=args.prompt_conf,
).export(args.output_path)
print("Mesh saved to", args.output_path)