# Source: Hugging Face Space file view (Space status: "Runtime error"),
# file size 8,542 bytes, revision 7d65c03.
import os
import time
from omegaconf import OmegaConf
import torch
from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
from utils.utils import instantiate_from_config
from huggingface_hub import hf_hub_download
from einops import repeat
import torchvision.transforms as transforms
from pytorch_lightning import seed_everything
from einops import rearrange
import argparse
import glob
from PIL import Image
import numpy as np
from moviepy.editor import VideoFileClip, concatenate_videoclips
class Image2Video:
    """Generate ToonCrafter interpolation videos from one or two keyframes.

    Construction downloads the checkpoint for the configured width (if it is
    not already on disk) and instantiates ``gpu_num`` copies of the model.
    :meth:`get_image` samples with the first copy and writes an ``.mp4`` into
    ``result_dir``.
    """

    def __init__(self, result_dir='./tmp/', gpu_num=1, resolution='256_256') -> None:
        """Download (if needed) and load the model(s).

        Args:
            result_dir: Directory for generated videos; created, including
                missing parents, if it does not exist.
            gpu_num: Number of model copies to build. Only ``model_list[0]``
                is used by :meth:`get_image`.
            resolution: ``"H_W"`` string such as ``"320_512"``. The width part
                selects both the checkpoint directory and the config file.

        Raises:
            AssertionError: if the checkpoint file is missing after the
                download step.
        """
        parts = resolution.split('_')
        self.resolution = (int(parts[0]), int(parts[1]))  # (h, w)
        self.download_model()
        self.result_dir = result_dir
        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(self.result_dir, exist_ok=True)
        ckpt_path = os.path.join(self._ckpt_dir(), 'model.ckpt')
        config_file = f'configs/inference_{self.resolution[1]}_v1.0.yaml'
        config = OmegaConf.load(config_file)
        model_config = config.pop("model", OmegaConf.create())
        # Presumably UNet gradient checkpointing, which only helps training.
        model_config['params']['unet_config']['params']['use_checkpoint'] = False
        # Check the checkpoint once rather than once per model copy.
        print(ckpt_path)
        assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
        model_list = []
        for _ in range(gpu_num):
            model = instantiate_from_config(model_config)
            model = load_model_checkpoint(model, ckpt_path)
            model.eval()
            model_list.append(model)
        self.model_list = model_list
        self.save_fps = 8

    def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None):
        """Generate a video that starts at ``image`` and, if given, ends at ``image2``.

        Args:
            image: Path (``str`` or ``os.PathLike``) to an image file, or an
                ``H x W x C`` array with 0-255 values. Files are converted to RGB.
            prompt: Text prompt for conditioning. It is also used as the file
                name when ``image`` is passed as an array.
            steps: DDIM steps, clamped to at most 60.
            cfg_scale: Classifier-free guidance scale forwarded to sampling.
            eta: DDIM eta forwarded to sampling.
            fs: Integer passed to the model as the ``"fs"`` condition.
            seed: Global seed set via ``seed_everything``.
            image2: Optional end keyframe (same accepted types as ``image``).
                When ``None``, the second half of the clip is conditioned on
                zeros and the final generated frame is dropped.

        Returns:
            Path of the written video, ``<result_dir>/<name>.mp4``. ``<name>``
            is the input file's stem, or the sanitized prompt for array input.
        """
        img_name = self._stem(image) if isinstance(image, (str, os.PathLike)) else ""
        image = self._load_rgb(image)
        if image2 is not None:
            image2 = self._load_rgb(image2)
        seed_everything(seed)
        transform = transforms.Compose([
            transforms.Resize(min(self.resolution)),
            transforms.CenterCrop(self.resolution),
        ])
        torch.cuda.empty_cache()
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()
        steps = min(steps, 60)
        gpu_id = 0
        model = self.model_list[gpu_id]
        model = model.half().cuda()
        try:
            batch_size = 1
            channels = model.model.diffusion_model.out_channels
            frames = model.temporal_length
            h, w = self.resolution[0] // 8, self.resolution[1] // 8
            noise_shape = [batch_size, channels, frames, h, w]
            with torch.no_grad(), torch.cuda.amp.autocast():
                text_emb = model.get_learned_conditioning([prompt])
                img_tensor, videos = self._keyframe_clip(model, transform, image, frames)
                if image2 is not None:
                    _, videos2 = self._keyframe_clip(model, transform, image2, frames)
                else:
                    # No end keyframe: pad with zeros so the clip (and its latent)
                    # has the full time length that noise_shape expects.
                    videos2 = torch.zeros_like(videos)
                videos = torch.cat([videos, videos2], dim=2)
                z, hs = self.get_latent_z_with_hidden_states(model, videos)
                # Keep only the first and last latent frames; the rest stay zero.
                img_tensor_repeat = torch.zeros_like(z)
                img_tensor_repeat[:, :, :1, :, :] = z[:, :, :1, :, :]
                img_tensor_repeat[:, :, -1:, :, :] = z[:, :, -1:, :, :]
                # Image embedding uses the un-resized start frame; presumably (b, l, c).
                cond_images = model.embedder(img_tensor.unsqueeze(0))
                img_emb = model.image_proj_model(cond_images)
                imtext_cond = torch.cat([text_emb, img_emb], dim=1)
                fs_tensor = torch.tensor([fs], dtype=torch.long, device=model.device)
                cond = {"c_crossattn": [imtext_cond], "fs": fs_tensor, "c_concat": [img_tensor_repeat]}
                batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale, hs=hs)
            if image2 is None:
                # The final frame was conditioned on zero padding, not a real
                # keyframe; drop it (index 3, presumably the time axis).
                batch_samples = batch_samples[:, :, :, :-1, ...]
            # Name the video after the input image, or the prompt for array input.
            video_filename = img_name or self._safe_name(prompt)
            save_videos(batch_samples, self.result_dir, filenames=[video_filename], fps=self.save_fps)
            print(f"Saved in {video_filename}. Time used: {(time.time() - start):.2f} seconds")
        finally:
            # Always move the model back off the GPU, even if sampling fails.
            model.cpu()
        return os.path.join(self.result_dir, video_filename + ".mp4")

    @staticmethod
    def _load_rgb(image):
        """Return ``image`` as an array; paths are opened, converted to RGB and closed."""
        if not isinstance(image, (str, os.PathLike)):
            return image
        with Image.open(image) as pil_image:
            # convert('RGB') drops alpha / expands grayscale to the 3 channels
            # the pipeline expects; np.array gives a writable copy for torch.
            return np.array(pil_image.convert('RGB'))

    @staticmethod
    def _stem(path):
        """Return the file name of ``path`` without its final extension."""
        return os.path.splitext(os.path.basename(path))[0]

    @staticmethod
    def _safe_name(prompt):
        """Turn ``prompt`` into a filename stem ('/' -> '_slash_', ' ' -> '_', <= 40 chars)."""
        name = prompt.replace("/", "_slash_").replace(" ", "_")[:40]
        return name or 'empty_prompt'

    @staticmethod
    def _keyframe_clip(model, transform, image, frames):
        """Turn one ``H x W x C`` keyframe array into a static clip.

        Values are mapped from [0, 255] to [-1, 1] (assuming 0-255 input), then
        resized and center-cropped by ``transform``.

        Returns:
            ``(img_tensor, clip)``. ``img_tensor`` is the normalized, un-resized
            ``(C, H, W)`` tensor. ``clip`` is the resized frame repeated as
            ``(1, C, frames // 2, h, w)``.
        """
        img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().half().to(model.device)
        img_tensor = (img_tensor / 255. - 0.5) * 2
        clip = transform(img_tensor).unsqueeze(0).unsqueeze(2)  # (1, C, 1, h, w)
        clip = repeat(clip, 'b c t h w -> b c (repeat t) h w', repeat=frames // 2)
        return img_tensor, clip

    def _ckpt_dir(self):
        """Local directory holding the checkpoint for the configured width."""
        return os.path.join('checkpoints', f'tooncrafter_{self.resolution[1]}_interp_v1')

    def download_model(self):
        """Fetch ``model.ckpt`` from the ``Doubiiu/ToonCrafter`` Hub repo if it is not on disk."""
        REPO_ID = 'Doubiiu/ToonCrafter'
        filename_list = ['model.ckpt']
        ckpt_dir = self._ckpt_dir()
        os.makedirs(ckpt_dir, exist_ok=True)
        for filename in filename_list:
            if not os.path.exists(os.path.join(ckpt_dir, filename)):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=ckpt_dir, local_dir_use_symlinks=False)

    def get_latent_z_with_hidden_states(self, model, videos):
        """Encode a ``(b, c, t, h, w)`` video with the first-stage model.

        Returns:
            ``(z, hidden_states)``. ``z`` is the latent reshaped back to
            ``(b, c', t, h', w')``. ``hidden_states`` is a list with one tensor
            per encoder hidden state, each keeping only its first and last frame.
        """
        b, _, t, _, _ = videos.shape
        # The encoder works on individual frames, so fold time into the batch.
        x = rearrange(videos, 'b c t h w -> (b t) c h w')
        encoder_posterior, hidden_states = model.first_stage_model.encode(x, return_hidden_states=True)
        hidden_states_first_last = []
        for hid in hidden_states:
            hid = rearrange(hid, '(b t) c h w -> b c t h w', t=t)
            hidden_states_first_last.append(torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2))
        z = model.get_first_stage_encoding(encoder_posterior).detach()
        z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
        return z, hidden_states_first_last
if __name__ == '__main__':
    # Interpolate a video between every consecutive pair of .png keyframes in
    # --image_dir (sorted by path), then join the pieces into one final video.
    parser = argparse.ArgumentParser(description='Image to Video Conversion')
    parser.add_argument('--image_dir', type=str, required=True, help='Path to the directory containing input images')
    parser.add_argument('--prompt', type=str, required=True, help='Prompt for the video')
    parser.add_argument('--steps', type=int, default=50, help='Number of steps')
    parser.add_argument('--cfg_scale', type=float, default=7.5, help='CFG scale')
    parser.add_argument('--eta', type=float, default=1.0, help='Eta value')
    parser.add_argument('--fs', type=int, default=3, help='FS value')
    parser.add_argument('--seed', type=int, default=123, help='Seed value')
    args = parser.parse_args()

    image_paths = sorted(glob.glob(os.path.join(args.image_dir, '*.png')))
    # Fail fast, before the expensive model download/load, if there is no pair to interpolate.
    if len(image_paths) < 2:
        parser.error(f"need at least two .png images in {args.image_dir!r}, found {len(image_paths)}")

    i2v = Image2Video("results", resolution="320_512")
    video_paths = []
    for img_path, img2_path in zip(image_paths, image_paths[1:]):
        video_path = i2v.get_image(img_path, args.prompt, args.steps, args.cfg_scale, args.eta, args.fs, args.seed, img2_path)
        video_paths.append(video_path)
        print('done', video_path)

    # Name the final video after the first input image.
    first_image_name = os.path.splitext(os.path.basename(image_paths[0]))[0]
    final_video_path = os.path.join(i2v.result_dir, f"{first_image_name}_final.mp4")

    # Concatenate the per-pair videos in order.
    # NOTE(review): consecutive clips share a keyframe, so the boundary frame is
    # likely duplicated in the output. Confirm whether it should be trimmed.
    clips = [VideoFileClip(vp) for vp in video_paths]
    try:
        final_clip = concatenate_videoclips(clips, method="compose")
        try:
            final_clip.write_videofile(final_video_path, codec="libx264", fps=i2v.save_fps)
        finally:
            final_clip.close()
    finally:
        # Release the readers opened for each input clip.
        for clip in clips:
            clip.close()
    print(f"Final video saved at {final_video_path}")