import os
import json
import torch
import random
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
from omegaconf import OmegaConf
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available

from utils.unet import UNet3DConditionModel
from utils.pipeline_magictime import MagicTimePipeline
from utils.util import save_videos_grid, load_weights
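
# MagicTime text-to-video inference. One of three run modes is selected from
# the CLI: interactive prompting (--human), batch prompts from a CSV
# (--run-csv) or an MSR-VTT style JSON (--run-json); otherwise the prompt list
# in the config file is rendered.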
def main(args):
    # A module-level counter gives each invocation of main() a distinct id.
    if 'counter' not in globals():
        globals()['counter'] = 0
    unique_id = globals()['counter']
    globals()['counter'] += 1

    # Results go to outputs/[<save_path>/]<config-stem>-<id>; bump the id
    # until the directory name is free.
    savedir_base = Path(args.config).stem
    savedir_prefix = os.path.join("outputs", args.save_path) if args.save_path else "outputs"
    savedir = os.path.join(savedir_prefix, f"{savedir_base}-{unique_id}")
    while os.path.exists(savedir):
        unique_id = globals()['counter']
        globals()['counter'] += 1
        savedir = os.path.join(savedir_prefix, f"{savedir_base}-{unique_id}")
    os.makedirs(savedir)
| print(f"The results will be save to {savedir}") | |
| model_config = OmegaConf.load(args.config)[0] | |
| inference_config = OmegaConf.load(args.config)[1] | |
| if model_config.magic_adapter_s_path: | |
| print("Use MagicAdapter-S") | |
| if model_config.magic_adapter_t_path: | |
| print("Use MagicAdapter-T") | |
| if model_config.magic_text_encoder_path: | |
| print("Use Magic_Text_Encoder") | |
    samples = []

    # create validation pipeline: frozen CLIP text encoder and VAE, plus the
    # 2D UNet inflated into the 3D UNet used for video
    tokenizer = CLIPTokenizer.from_pretrained(model_config.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_config.pretrained_model_path, subfolder="text_encoder").cuda()
    vae = AutoencoderKL.from_pretrained(model_config.pretrained_model_path, subfolder="vae").cuda()
    unet = UNet3DConditionModel.from_pretrained_2d(
        model_config.pretrained_model_path, subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    ).cuda()

    # set xformers
    if is_xformers_available() and (not args.without_xformers):
        unet.enable_xformers_memory_efficient_attention()

    pipeline = MagicTimePipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
    ).to("cuda")

    # Layer the motion module, optional DreamBooth weights, and the MagicTime
    # adapters / text encoder on top of the base weights.
    pipeline = load_weights(
        pipeline,
        motion_module_path=model_config.get("motion_module", ""),
        dreambooth_model_path=model_config.get("dreambooth_path", ""),
        magic_adapter_s_path=model_config.get("magic_adapter_s_path", ""),
        magic_adapter_t_path=model_config.get("magic_adapter_t_path", ""),
        magic_text_encoder_path=model_config.get("magic_text_encoder_path", ""),
    ).to("cuda")
    sample_idx = 0
    if args.human:
        # Interactive mode: read prompts from stdin until the user types 'exit'.
        while True:
            user_prompt = input("Enter your prompt (or type 'exit' to quit): ")
            if user_prompt.lower() == "exit":
                break

            random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
            torch.manual_seed(random_seed)
            print(f"current seed: {random_seed}")
            print(f"sampling {user_prompt} ...")

            sample = pipeline(
                user_prompt,
                num_inference_steps=model_config.steps,
                guidance_scale=model_config.guidance_scale,
                width=model_config.W,
                height=model_config.H,
                video_length=model_config.L,
            ).videos

            # Name the file after the index, seed, and first ten words of the prompt.
            prompt_for_filename = "-".join(user_prompt.replace("/", "").split(" ")[:10])
            save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.gif")
            print(f"save to {savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.gif")
            sample_idx += 1
    elif args.run_csv:
        print("run_csv")
        # Batch mode over a CSV: column 'name' holds the prompt, 'videoid'
        # names the output file.
        data = pd.read_csv(args.run_csv)
        for _, row in data.iterrows():
            user_prompt = row['name']
            videoid = row['videoid']

            random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
            torch.manual_seed(random_seed)
            print(f"current seed: {random_seed}")
            print(f"sampling {user_prompt} ...")

            sample = pipeline(
                user_prompt,
                num_inference_steps=model_config.steps,
                guidance_scale=model_config.guidance_scale,
                width=model_config.W,
                height=model_config.H,
                video_length=model_config.L,
            ).videos
            save_videos_grid(sample, f"{savedir}/sample/{videoid}.gif")
            print(f"save to {savedir}/sample/{videoid}.gif")
    elif args.run_json:
        print("run_json")
        # Batch mode over an MSR-VTT style JSON: each item carries a caption,
        # a video_id, and a sen_id.
        with open(args.run_json, 'r') as file:
            data = json.load(file)
        prompts = [item['caption'] for item in data]
        videoids = [item['video_id'] for item in data]
        senids = [item['sen_id'] for item in data]

        # Broadcast a single negative prompt / seed across all prompts.
        n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
        random_seeds = model_config.get("seed", [-1])
        random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
        random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

        model_config.random_seed = []
        for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
            # This branch writes to a fixed MSRVTT/sample directory so that
            # already-rendered clips are skipped when the job is re-run.
            filename = f"MSRVTT/sample/{videoids[prompt_idx]}-{senids[prompt_idx]}.gif"
            if os.path.exists(filename):
                print(f"File {filename} already exists, skipping...")
                continue

            # manually set random seed for reproduction
            if random_seed != -1:
                torch.manual_seed(random_seed)
            else:
                torch.seed()
            model_config.random_seed.append(torch.initial_seed())

            print(f"current seed: {torch.initial_seed()}")
            print(f"sampling {prompt} ...")
            sample = pipeline(
                prompt,
                negative_prompt=n_prompt,
                num_inference_steps=model_config.steps,
                guidance_scale=model_config.guidance_scale,
                width=model_config.W,
                height=model_config.H,
                video_length=model_config.L,
            ).videos
            save_videos_grid(sample, filename)
            print(f"save to {filename}")
    else:
        # Default mode: render every prompt listed in the config file.
        prompts = model_config.prompt
        n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

        random_seeds = model_config.get("seed", [-1])
        random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
        random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

        model_config.random_seed = []
        for prompt, n_prompt, random_seed in zip(prompts, n_prompts, random_seeds):
            # manually set random seed for reproduction
            if random_seed != -1:
                torch.manual_seed(random_seed)
                np.random.seed(random_seed)
                random.seed(random_seed)
            else:
                torch.seed()
            model_config.random_seed.append(torch.initial_seed())

            print(f"current seed: {torch.initial_seed()}")
            print(f"sampling {prompt} ...")
            sample = pipeline(
                prompt,
                negative_prompt=n_prompt,
                num_inference_steps=model_config.steps,
                guidance_scale=model_config.guidance_scale,
                width=model_config.W,
                height=model_config.H,
                video_length=model_config.L,
            ).videos
            samples.append(sample)

            prompt = "-".join(prompt.replace("/", "").split(" ")[:10])
            save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt}.gif")
            print(f"save to {savedir}/sample/{sample_idx}-{random_seed}-{prompt}.gif")
            sample_idx += 1

        # Merge all clips into one grid and save the resolved config alongside.
        samples = torch.concat(samples)
        save_videos_grid(samples, f"{savedir}/merge_all.gif", n_rows=4)
        OmegaConf.save(model_config, f"{savedir}/model_config.yaml")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--without-xformers", action="store_true")
    parser.add_argument("--human", action="store_true", help="Enable human mode for interactive video generation")
    parser.add_argument("--run-csv", type=str, default=None)
    parser.add_argument("--run-json", type=str, default=None)
    parser.add_argument("--save-path", type=str, default=None)

    args = parser.parse_args()
    main(args)
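
# Example invocations (a sketch; the script and config names below are
# placeholders, not paths shipped with the repo):
#   python inference.py --config configs/sample.yaml
#   python inference.py --config configs/sample.yaml --human
#   python inference.py --config configs/sample.yaml --run-csv prompts.csv --save-path csv_runs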