Spaces:
Sleeping
Sleeping
File size: 2,810 Bytes
b2a27a7 b3806ac b2a27a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import argparse
import os
import sys
from glob import glob
from typing import Any, Union
import numpy as np
import torch
import trimesh
from huggingface_hub import snapshot_download
from PIL import Image
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from triposg.pipelines.pipeline_triposg_scribble import TripoSGScribblePipeline
@torch.no_grad()
def run_triposg_scribble(
    pipe: Any,
    image_input: Union[str, Image.Image],
    prompt: str,
    seed: int,
    num_inference_steps: int = 16,
    scribble_confidence: float = 0.4,
    prompt_confidence: float = 1.0,
) -> trimesh.Trimesh:
    """Generate a mesh from a scribble image and a text prompt.

    Args:
        pipe: Pipeline object; it is called with the keyword arguments below
            and must expose a ``device`` attribute and return an object with
            a ``samples`` sequence.
        image_input: Either a path (or anything ``PIL.Image.open`` accepts)
            or an already-loaded ``PIL.Image.Image``. The image is converted
            to RGB before being handed to the pipeline.
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the ``torch.Generator`` passed to the pipeline.
        num_inference_steps: Forwarded to the pipeline unchanged.
        scribble_confidence: Passed as ``cross_attention_2_scale``.
        prompt_confidence: Passed as ``cross_attention_scale``.

    Returns:
        A ``trimesh.Trimesh`` built from the first sample of the pipeline
        output.
    """
    # The signature promises support for both a path and a PIL image;
    # ``Image.open`` only handles the former, so branch on the actual type.
    if isinstance(image_input, Image.Image):
        img_pil = image_input.convert("RGB")
    else:
        # ``convert`` returns a loaded copy, so the file can be closed right away.
        with Image.open(image_input) as opened_image:
            img_pil = opened_image.convert("RGB")

    outputs = pipe(
        image=img_pil,
        prompt=prompt,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=0,  # this is a CFG-distilled model
        attention_kwargs={
            "cross_attention_scale": prompt_confidence,
            "cross_attention_2_scale": scribble_confidence,
        },
        # there are some boundary problems when using the flash decoder with this model
        use_flash_decoder=False,
        # 256 resolution for faster inference
        dense_octree_depth=9,
        hierarchical_octree_depth=9,
    ).samples[0]

    # outputs[0] / outputs[1] are passed positionally as Trimesh's
    # (vertices, faces) -- presumably that is what the pipeline emits.
    mesh = trimesh.Trimesh(
        outputs[0].astype(np.float32),
        np.ascontiguousarray(outputs[1]),
    )
    return mesh
if __name__ == "__main__":
    # Command-line interface: the input image and the prompt are mandatory,
    # everything else falls back to a default.
    cli = argparse.ArgumentParser()
    cli.add_argument("--image-input", type=str, required=True)
    cli.add_argument("--prompt", type=str, required=True)
    cli.add_argument("--output-path", type=str, default="./output.glb")
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--num-inference-steps", type=int, default=16)
    # The scribble confidence is meant to be tuned; 0.3-0.5 often gives good
    # results for hand-drawn sketches.
    cli.add_argument("--scribble-conf", type=float, default=0.4)
    cli.add_argument("--prompt-conf", type=float, default=1.0)
    args = cli.parse_args()

    # Target device and precision the pipeline is moved to.
    device = "cuda"
    dtype = torch.float16

    # Fetch the pretrained weights into a local directory.
    triposg_scribble_weights_dir = "pretrained_weights/TripoSG-scribble"
    snapshot_download(
        repo_id="VAST-AI/TripoSG-scribble",
        local_dir=triposg_scribble_weights_dir,
    )

    # Build the TripoSG scribble pipeline from the downloaded weights.
    pipe: TripoSGScribblePipeline = TripoSGScribblePipeline.from_pretrained(
        triposg_scribble_weights_dir
    )
    pipe = pipe.to(device, dtype)

    # Run inference, then write the resulting mesh to disk.
    generated_mesh = run_triposg_scribble(
        pipe,
        image_input=args.image_input,
        prompt=args.prompt,
        seed=args.seed,
        num_inference_steps=args.num_inference_steps,
        scribble_confidence=args.scribble_conf,
        prompt_confidence=args.prompt_conf,
    )
    generated_mesh.export(args.output_path)
    print("Mesh saved to", args.output_path)
|