File size: 1,537 Bytes
a176955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import os
from diffusers import DDPMScheduler
from pipeline import Zero123PlusPipeline
from huggingface_hub import hf_hub_download
from PIL import Image


def load_z123_pipe(device_number):
    """Build the Zero123++ pipeline and load the custom UNet weights into it.

    Args:
        device_number: CUDA device index. Only used when CUDA is available;
            otherwise the pipeline is placed on the CPU.

    Returns:
        The pipeline object produced by ``Zero123PlusPipeline.from_pretrained``
        after its scheduler has been replaced, its UNet weights overwritten,
        and the pipeline moved to the selected device.
    """
    if torch.cuda.is_available():
        target = torch.device(f"cuda:{device_number}")
    else:
        target = torch.device("cpu")

    pipe = Zero123PlusPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2",
        torch_dtype=torch.float16,
    )
    # Swap in a DDPM scheduler built from the existing scheduler's config;
    # DDPM supports custom timesteps.
    scheduler_config = pipe.scheduler.config
    pipe.scheduler = DDPMScheduler.from_config(scheduler_config)

    # Custom white-background UNet: use the local checkpoint when present,
    # otherwise fetch it from the Hugging Face Hub.
    local_ckpt = "ckpts/diffusion_pytorch_model.bin"
    ckpt_file = (
        local_ckpt
        if os.path.exists(local_ckpt)
        else hf_hub_download(
            repo_id="TencentARC/InstantMesh",
            filename="diffusion_pytorch_model.bin",
            repo_type="model",
        )
    )
    # NOTE(review): depending on the installed torch version's default for
    # ``weights_only``, torch.load may unpickle arbitrary objects from this
    # file -- consider passing weights_only=True if the supported torch
    # versions allow it.
    weights = torch.load(ckpt_file, map_location="cpu")
    # strict=True: every key must match the UNet's own state dict.
    pipe.unet.load_state_dict(weights, strict=True)

    pipe.to(target)
    return pipe


def add_white_bg(image):
    """Flatten an image with an alpha band onto a white canvas.

    Args:
        image: A PIL image.

    Returns:
        A new "RGB" image of the same size with ``image`` pasted over white
        when its mode is "RGBA" or "LA"; otherwise ``image`` itself, returned
        unchanged (same object).
    """
    # Any mode other than the two handled alpha modes is handed back as-is.
    if image.mode not in ("RGBA", "LA"):
        return image

    canvas = Image.new("RGB", image.size, (255, 255, 255))
    # The last band of the image is used as the paste mask (the alpha
    # channel for the two modes handled here).
    alpha = image.split()[-1]
    canvas.paste(image, mask=alpha)
    return canvas