Update device handling across multiple modules to support automatic selection of CUDA or CPU based on availability. This change enhances compatibility and performance on systems with or without GPU support, ensuring consistent behavior in model loading and data processing.
597a667
import torch | |
from libs.base_utils import do_resize_content | |
from imagedream.ldm.util import ( | |
instantiate_from_config, | |
get_obj_from_str, | |
) | |
from omegaconf import OmegaConf | |
from PIL import Image | |
import numpy as np | |
class TwoStagePipeline(object):
    """Run two configured samplers back to back on a single reference image.

    Stage 1 is conditioned on the reference image and produces several
    images; stage 2 is conditioned on the reference image plus the stage-1
    images and produces the final images.
    """

    def __init__(
        self,
        stage1_model_config,
        stage2_model_config,
        stage1_sampler_config,
        stage2_sampler_config,
        device="cuda",
        dtype=torch.float16,
        resize_rate=1,
    ) -> None:
        """
        only for two stage generate process.
        - the first stage was condition on single pixel image, generate multi-view pixel image, based on the v2pp config
        - the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config

        Args:
            stage1_model_config: object exposing ``.config`` (path of the
                model config file) and ``.resume`` (checkpoint path) for
                the stage-1 model.
            stage2_model_config: same as above, for the stage-2 model.
            stage1_sampler_config: object exposing ``.target`` (import path
                of the sampler class) and ``.params`` (extra keyword
                arguments for it) for stage 1.
            stage2_sampler_config: same as above, for stage 2.
            device: device to run on. A falsy value selects CUDA when it is
                available and CPU otherwise. A CUDA device requested on a
                machine without CUDA falls back to CPU with a warning.
            dtype: dtype both models are cast to.
            resize_rate: forwarded to ``do_resize_content`` in ``__call__``.
        """
        # NOTE(review): the default dtype is float16; confirm the models'
        # operators support half precision when the CPU fallback is taken.
        device = self._resolve_device(device)
        self.resize_rate = resize_rate

        self.stage1_model = self._load_model(stage1_model_config, device, dtype)
        self.stage2_model = self._load_model(stage2_model_config, device, dtype)

        # The models are tagged with the device they were moved to.
        self.stage1_model.device = device
        self.stage2_model.device = device
        self.device = device
        self.dtype = dtype

        self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)(
            self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params
        )
        self.stage2_sampler = get_obj_from_str(stage2_sampler_config.target)(
            self.stage2_model, device=device, dtype=dtype, **stage2_sampler_config.params
        )

    @staticmethod
    def _resolve_device(device):
        """Return a usable device, falling back to CPU when CUDA is missing."""
        import warnings

        cuda_available = torch.cuda.is_available()
        if not device:
            return "cuda" if cuda_available else "cpu"
        if torch.device(device).type == "cuda" and not cuda_available:
            warnings.warn(
                f"CUDA device {device!r} requested but CUDA is not available; "
                "falling back to 'cpu'."
            )
            return "cpu"
        return device

    @staticmethod
    def _load_model(model_config, device, dtype):
        """Instantiate a model from its config, load its checkpoint and move it."""
        model = instantiate_from_config(OmegaConf.load(model_config.config).model)
        # NOTE(security): torch.load unpickles the checkpoint; only load
        # checkpoint files that come from a trusted source.
        state_dict = torch.load(model_config.resume, map_location=device)
        # strict=False: mismatching keys between checkpoint and model are tolerated.
        model.load_state_dict(state_dict, strict=False)
        return model.to(device).to(dtype)

    @staticmethod
    def _flatten_to_rgb(image):
        """Return an RGB copy of a PIL image, compositing RGBA inputs first."""
        if image.mode == "RGBA":
            background = Image.new('RGBA', image.size, (0, 0, 0, 0))
            return Image.alpha_composite(background, image).convert("RGB")
        return image.convert("RGB")

    @staticmethod
    def _load_rgb_image(pixel_img):
        """Normalise a file path or PIL image into an RGB PIL image.

        Raises:
            TypeError: if ``pixel_img`` is neither a ``str`` nor a PIL image.
        """
        if isinstance(pixel_img, str):
            # Convert inside the context manager so the pixel data is read
            # before the file handle is released.
            with Image.open(pixel_img) as opened:
                return TwoStagePipeline._flatten_to_rgb(opened)
        if isinstance(pixel_img, Image.Image):
            return TwoStagePipeline._flatten_to_rgb(pixel_img)
        raise TypeError(
            "pixel_img must be a file path (str) or a PIL.Image.Image, "
            f"got {type(pixel_img).__name__}"
        )

    def stage1_sample(
        self,
        pixel_img,
        prompt="3D assets",
        neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.",
        step=50,
        scale=5,
        ddim_eta=0.0,
    ):
        """Run the stage-1 sampler on a reference image.

        Args:
            pixel_img: file path or PIL image used as the image prompt.
            prompt: text prompt passed to the sampler.
            neg_texts: negative prompt, encoded with the stage-1 model's
                ``get_learned_conditioning``.
            step: forwarded to the sampler as ``step``.
            scale: forwarded to the sampler as ``scale``.
            ddim_eta: forwarded to the sampler as ``ddim_eta``.

        Returns:
            List of PIL images built from the sampler output, with the entry
            at the sampler's ``ref_position`` removed.

        Raises:
            TypeError: if ``pixel_img`` is neither a ``str`` nor a PIL image.
        """
        pixel_img = self._load_rgb_image(pixel_img)
        sampler = self.stage1_sampler
        uc = sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
        stage1_images = sampler.i2i(
            sampler.model,
            sampler.size,
            prompt,
            uc=uc,
            sampler=sampler.sampler,
            ip=pixel_img,
            step=step,
            scale=scale,
            batch_size=sampler.batch_size,
            ddim_eta=ddim_eta,
            dtype=sampler.dtype,
            device=sampler.device,
            camera=sampler.camera,
            num_frames=sampler.num_frames,
            pixel_control=(sampler.mode == "pixel"),
            transform=sampler.image_transform,
            offset_noise=sampler.offset_noise,
        )
        stage1_images = [Image.fromarray(img) for img in stage1_images]
        # Drop the output stored at the sampler's reference position.
        stage1_images.pop(sampler.ref_position)
        return stage1_images

    def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50):
        """Run the stage-2 sampler on the reference image and stage-1 output.

        The text prompt is fixed to ``"3D assets"``, ``ddim_eta`` is fixed to
        0.0, and the sampler's own ``uc`` is used.

        Args:
            pixel_img: file path or PIL image used as the image prompt.
            stage1_images: images produced by ``stage1_sample``.
            scale: forwarded to the sampler as ``scale``.
            step: forwarded to the sampler as ``step``.

        Returns:
            List of PIL images built from the sampler output.

        Raises:
            TypeError: if ``pixel_img`` is neither a ``str`` nor a PIL image.
        """
        pixel_img = self._load_rgb_image(pixel_img)
        sampler = self.stage2_sampler
        stage2_images = sampler.i2iStage2(
            sampler.model,
            sampler.size,
            "3D assets",
            sampler.uc,
            sampler.sampler,
            pixel_images=stage1_images,
            ip=pixel_img,
            step=step,
            scale=scale,
            batch_size=sampler.batch_size,
            ddim_eta=0.0,
            dtype=sampler.dtype,
            device=sampler.device,
            camera=sampler.camera,
            num_frames=sampler.num_frames,
            pixel_control=(sampler.mode == "pixel"),
            transform=sampler.image_transform,
            offset_noise=sampler.offset_noise,
        )
        return [Image.fromarray(img) for img in stage2_images]

    def set_seed(self, seed):
        """Assign the same seed to both stage samplers."""
        self.stage1_sampler.seed = seed
        self.stage2_sampler.seed = seed

    def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
        """Resize the reference image, then run stage 1 followed by stage 2.

        Returns:
            Dict with keys ``"ref_img"`` (the resized reference image),
            ``"stage1_images"`` and ``"stage2_images"``.
        """
        pixel_img = do_resize_content(pixel_img, self.resize_rate)
        stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
        stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step)
        return {
            "ref_img": pixel_img,
            "stage1_images": stage1_images,
            "stage2_images": stage2_images,
        }
if __name__ == "__main__":
    # Demo run: build the pipeline from the two config files, process the
    # sample image, and save each stage's outputs as one concatenated image.
    cfg_stage1 = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
    cfg_stage2 = OmegaConf.load("configs/stage2-v2-snr.yaml").config

    pipeline = TwoStagePipeline(
        stage1_model_config=cfg_stage1.models,
        stage2_model_config=cfg_stage2.models,
        stage1_sampler_config=cfg_stage1.sampler,
        stage2_sampler_config=cfg_stage2.sampler,
    )

    outputs = pipeline(Image.open("assets/astronaut.png"))
    stage1_views = outputs["stage1_images"]
    stage2_views = outputs["stage2_images"]

    # Join each stage's images along axis 1 before anything is written.
    stage1_strip = np.concatenate(stage1_views, axis=1)
    stage2_strip = np.concatenate(stage2_views, axis=1)

    for strip, filename in (
        (stage1_strip, "pixel_images.png"),
        (stage2_strip, "xyz_images.png"),
    ):
        Image.fromarray(strip).save(filename)