import os

import torch
from diffusers import DDPMScheduler
from huggingface_hub import hf_hub_download
from PIL import Image

from pipeline import Zero123PlusPipeline


def load_z123_pipe(device_number):
    """Load the Zero123++ v1.2 pipeline with the InstantMesh fine-tuned UNet."""
    device = torch.device(
        f"cuda:{device_number}" if torch.cuda.is_available() else "cpu"
    )

    pipeline = Zero123PlusPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2", torch_dtype=torch.float16
    )

    # Swap the default scheduler for a DDPM scheduler built from the same config.
    pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)

    # Prefer a local UNet checkpoint; fall back to downloading it from the Hub.
    unet_path = "ckpts/diffusion_pytorch_model.bin"
    if os.path.exists(unet_path):
        unet_ckpt_path = unet_path
    else:
        unet_ckpt_path = hf_hub_download(
            repo_id="TencentARC/InstantMesh",
            filename="diffusion_pytorch_model.bin",
            repo_type="model",
        )
    state_dict = torch.load(unet_ckpt_path, map_location="cpu")
    pipeline.unet.load_state_dict(state_dict, strict=True)

    pipeline.to(device)
    return pipeline


def add_white_bg(image):
    """Composite an RGBA/LA image onto a white background and return an RGB image."""
    if image.mode in ("RGBA", "LA"):
        white_bg = Image.new("RGB", image.size, (255, 255, 255))
        # Use the alpha channel as the paste mask so transparent pixels become white.
        white_bg.paste(image, mask=image.split()[-1])
        return white_bg
    return image
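

# Example usage (a minimal sketch, not part of the original module): load the
# pipeline on GPU 0, composite an input image onto a white background, and run
# multi-view generation. The input/output paths and the num_inference_steps
# value are illustrative assumptions, and the call below assumes the local
# Zero123PlusPipeline follows the usual diffusers convention of accepting a
# conditioning image and returning an output with an `.images` list.
if __name__ == "__main__":
    pipe = load_z123_pipe(0)

    cond_image = add_white_bg(Image.open("example_input.png"))  # hypothetical path

    views = pipe(cond_image, num_inference_steps=75).images[0]
    views.save("example_views.png")  # hypothetical output path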