File size: 2,810 Bytes
b2a27a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3806ac
b2a27a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import os
import sys
from glob import glob
from typing import Any, Union

import numpy as np
import torch
import trimesh
from huggingface_hub import snapshot_download
from PIL import Image

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from triposg.pipelines.pipeline_triposg_scribble import TripoSGScribblePipeline



@torch.no_grad()
def run_triposg_scribble(
    pipe: Any,
    image_input: Union[str, Image.Image],
    prompt: str,
    seed: int,
    num_inference_steps: int = 16,
    scribble_confidence: float = 0.4,
    prompt_confidence: float = 1.0
) -> trimesh.Trimesh:
    """Run the scribble pipeline on one image/prompt pair and build a mesh.

    Args:
        pipe: Callable pipeline object. It must expose a ``device`` attribute
            and, when called, return an object with a ``samples`` sequence.
        image_input: Either an already-loaded ``PIL.Image.Image`` or something
            ``PIL.Image.open`` accepts (typically a file path). The image is
            converted to RGB before being handed to the pipeline.
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the ``torch.Generator`` created on ``pipe.device``.
        num_inference_steps: Forwarded to the pipeline unchanged.
        scribble_confidence: Forwarded as ``cross_attention_2_scale`` in
            ``attention_kwargs``.
        prompt_confidence: Forwarded as ``cross_attention_scale`` in
            ``attention_kwargs``.

    Returns:
        A ``trimesh.Trimesh`` built from the first sample: element 0 (cast to
        float32) is passed as vertices and element 1 (made contiguous) as
        faces.
    """
    # The signature advertises both paths and PIL images, so only go through
    # Image.open when we were not handed an image object already.
    if isinstance(image_input, Image.Image):
        img_pil = image_input.convert("RGB")
    else:
        # convert() forces the pixel data to load, so the file handle can be
        # released as soon as the RGB copy exists.
        with Image.open(image_input) as src_img:
            img_pil = src_img.convert("RGB")

    outputs = pipe(
        image=img_pil,
        prompt=prompt,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=0, # this is a CFG-distilled model
        attention_kwargs={"cross_attention_scale": prompt_confidence, "cross_attention_2_scale": scribble_confidence},
        use_flash_decoder=False, # there're some boundary problems when using flash decoder with this model
        dense_octree_depth=9, hierarchical_octree_depth=9 # 256 resolution for faster inference
    ).samples[0]
    mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
    return mesh


if __name__ == "__main__":
    # Command-line interface: an input image and a prompt are mandatory,
    # everything else has a default.
    cli = argparse.ArgumentParser()
    cli.add_argument("--image-input", type=str, required=True)
    cli.add_argument("--prompt", type=str, required=True)
    cli.add_argument("--output-path", type=str, default="./output.glb")
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--num-inference-steps", type=int, default=16)
    # feel free to tune the scribble confidence, 0.3-0.5 often gives good results for hand-drawn sketches
    cli.add_argument("--scribble-conf", type=float, default=0.4)
    cli.add_argument("--prompt-conf", type=float, default=1.0)
    args = cli.parse_args()

    # Target device and precision the pipeline is moved to.
    device, dtype = "cuda", torch.float16

    # Fetch the pretrained weights into a local directory.
    triposg_scribble_weights_dir = "pretrained_weights/TripoSG-scribble"
    snapshot_download(
        repo_id="VAST-AI/TripoSG-scribble",
        local_dir=triposg_scribble_weights_dir,
    )

    # Build the pipeline from the downloaded weights.
    pipe: TripoSGScribblePipeline = TripoSGScribblePipeline.from_pretrained(
        triposg_scribble_weights_dir
    ).to(device, dtype)

    # Run inference, then write the resulting mesh to disk.
    inference_kwargs = {
        "image_input": args.image_input,
        "prompt": args.prompt,
        "seed": args.seed,
        "num_inference_steps": args.num_inference_steps,
        "scribble_confidence": args.scribble_conf,
        "prompt_confidence": args.prompt_conf,
    }
    mesh = run_triposg_scribble(pipe, **inference_kwargs)
    mesh.export(args.output_path)
    print("Mesh saved to", args.output_path)