# Spaces:
# Paused
# Paused
r"""
This is a parallel inference script for CogVideo. The original script
can be found in the xDiT project at
https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py
By using this code, the inference process is parallelized across multiple GPUs,
and thus sped up.
Usage:
1. pip install xfuser
2. mkdir results
3. run the following command to generate video
torchrun --nproc_per_node=4 parallel_inference_xdit.py \
    --model <cogvideox-model-path> --ulysses_degree 1 --ring_degree 2 \
    --use_cfg_parallel --height 480 --width 720 --num_frames 9 \
    --prompt 'A small dog.'
You can also use the run.sh file in the same folder to automate running this
code for batch generation of videos, by running:
sh ./run.sh
"""
import time | |
import torch | |
import torch.distributed | |
from diffusers import AutoencoderKLTemporalDecoder | |
from xfuser import xFuserCogVideoXPipeline, xFuserArgs | |
from xfuser.config import FlexibleArgumentParser | |
from xfuser.core.distributed import ( | |
get_world_group, | |
get_data_parallel_rank, | |
get_data_parallel_world_size, | |
get_runtime_state, | |
is_dp_last_group, | |
) | |
from diffusers.utils import export_to_video | |
def _num_attention_heads(model_path, default=30):
    """Return the attention-head count declared in the checkpoint's transformer config.

    Reads ``<model_path>/transformer/config.json`` through diffusers'
    ``ConfigMixin.load_config``. The head count may differ between CogVideoX
    checkpoints, so the Ulysses divisibility check should not assume a single
    value. If the config cannot be read for any reason, ``default`` (the value
    this script previously hard-coded) is returned, reproducing the old check.
    """
    try:
        from diffusers import CogVideoXTransformer3DModel

        config = CogVideoXTransformer3DModel.load_config(
            model_path, subfolder="transformer"
        )
    except Exception:
        # Best-effort only: fall back to the old hard-coded value instead of
        # failing before the pipeline has even been built.
        return default
    return int(config.get("num_attention_heads", default))


def main():
    """Run one CogVideoX text-to-video generation with xFuser parallelism.

    Parses the xFuser CLI arguments, builds the parallel pipeline on this
    rank's GPU (or with CPU offload), generates a video for
    ``input_config.prompt``, and saves it under ``results/`` from the rank(s)
    for which ``is_dp_last_group()`` is true. It then prints the elapsed time
    and peak GPU memory from the last world rank and finally tears down the
    distributed environment.

    Raises:
        ValueError: if ``ulysses_degree`` does not divide the transformer's
            number of attention heads.
    """
    import os  # local import keeps this block self-contained

    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)

    # The Ulysses degree must divide the attention-head count. Read the count
    # from the checkpoint rather than hard-coding it.
    num_heads = _num_attention_heads(engine_args.model)
    if engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0:
        raise ValueError(
            f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor "
            f"of the number of heads ({num_heads})"
        )

    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank
    # Pass one explicit device to placement AND to both peak-memory calls:
    # the no-argument reset acts on the *current* CUDA device, which this
    # script never sets itself (notably on the offload path).
    device = torch.device(f"cuda:{local_rank}")

    pipe = xFuserCogVideoXPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
    )
    if args.enable_sequential_cpu_offload:
        # NOTE(review): despite the flag's name this enables diffusers'
        # *model-level* CPU offload, not sequential (per-submodule) offload.
        # Kept as-is; confirm sequential offload works with xFuser before
        # switching to pipe.enable_sequential_cpu_offload.
        pipe.enable_model_cpu_offload(gpu_id=local_rank)
    else:
        pipe = pipe.to(device)
    # Always enable tiling and slicing to avoid VAE OOM while batch size > 1
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    torch.cuda.reset_peak_memory_stats(device)
    start_time = time.perf_counter()  # monotonic clock for durations
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        num_frames=input_config.num_frames,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        # Every rank seeds an identical CPU generator with the same seed.
        generator=torch.Generator().manual_seed(input_config.seed),
        guidance_scale=6,
        use_dynamic_cfg=True,
    ).frames[0]
    elapsed_time = time.perf_counter() - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=device)

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if is_dp_last_group():
        resolution = f"{input_config.width}x{input_config.height}"
        output_filename = f"results/cogvideox_{parallel_info}_{resolution}.mp4"
        # Create the output directory instead of crashing after the expensive
        # generation; exist_ok tolerates several ranks creating it at once.
        os.makedirs(os.path.dirname(output_filename), exist_ok=True)
        export_to_video(output, output_filename, fps=8)
        print(f"output saved to {output_filename}")
    if get_world_group().rank == get_world_group().world_size - 1:
        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
    # `destory` (sic) is how the method is spelled in xFuser's runtime-state API.
    get_runtime_state().destory_distributed_env()
if __name__ == "__main__":
    # Script entry point; the module docstring shows launching it with torchrun
    # (one process per GPU via --nproc_per_node).
    main()