# FastAPI inference server for audio-driven portrait video generation.
# Rank 0 handles HTTP requests; all other ranks wait in a broadcast loop and
# join in for the sequence-parallel sampling step.
import os
import threading
import traceback
import warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import uvicorn
from fastapi import FastAPI, Body
from transformers import AutoFeatureExtractor, WhisperModel

# Provides process_input_dict, process_output_dict and data_preprocess_server.
from hymm_gradio.tool_for_end2end import *
from hymm_sp.config import parse_args
from hymm_sp.sample_inference_audio import HunyuanVideoSampler
from hymm_sp.modules.parallel_states import (
    initialize_distributed,
    nccl_info,
)
from hymm_sp.data_kits.face_align import AlignImage

warnings.filterwarnings("ignore")
# Root directory for model weights; must be set in the environment.
MODEL_OUTPUT_PATH = os.environ.get('MODEL_BASE')

app = FastAPI()
# Serializes requests: only one generation job may run at a time.
rlock = threading.RLock()
@app.post("/predict")  # route path is an assumption; match it to the client
def predict(data=Body(...)):
    """Rank-0 HTTP entry point: run one generation job if none is in flight."""
    is_acquire = False
    try:
        # Reject concurrent requests instead of queueing them.
        is_acquire = rlock.acquire(blocking=False)
        if is_acquire:
            return predict_wrap(data)
    except Exception:
        print(traceback.format_exc())
    finally:
        if is_acquire:
            rlock.release()
    return {"errCode": -1, "info": "broken"}
def predict_wrap(input_dict=None):
    """Runs on every rank: rank 0 parses and preprocesses the request, then
    all ranks meet at a broadcast and sample the video in parallel."""
    if input_dict is None:
        input_dict = {}
    rank = local_rank = 0
    if nccl_info.sp_size > 1:
        rank = local_rank = torch.distributed.get_rank()
        print(f"sp_size={nccl_info.sp_size}, rank {rank} local_rank {local_rank}")
    try:
        print(f"----- rank = {rank}")
        if rank == 0:
            input_dict = process_input_dict(input_dict)
            print('------- start to predict -------')

            # Parse input arguments.
            image_path = input_dict["image_path"]
            driving_audio_path = input_dict["audio_path"]
            prompt = input_dict["prompt"]
            save_fps = input_dict.get("save_fps", 25)

            if image_path is None or driving_audio_path is None:
                print("errCode: -3, input content is not valid!")
                # Early return on rank 0 skips the broadcast below; worker
                # ranks simply keep waiting for the next valid request.
                return {
                    "errCode": -3,
                    "content": [{"buffer": None}],
                    "info": "input content is not valid",
                }
            # Preprocess the input batch (image, audio features, prompts).
            torch.cuda.synchronize()
            preprocess_start = datetime.now()
            try:
                model_kwargs_tmp = data_preprocess_server(
                    args, image_path, driving_audio_path, prompt, feature_extractor
                )
            except Exception:
                traceback.print_exc()
                print("errCode: -2, preprocess failed!")
                return {
                    "errCode": -2,
                    "content": [{"buffer": None}],
                    "info": "failed to preprocess input data",
                }

            text_prompt = model_kwargs_tmp["text_prompt"]
            audio_path = model_kwargs_tmp["audio_path"]
            image_path = model_kwargs_tmp["image_path"]
            fps = model_kwargs_tmp["fps"]
            audio_prompts = model_kwargs_tmp["audio_prompts"]
            audio_len = model_kwargs_tmp["audio_len"]
            motion_bucket_id_exps = model_kwargs_tmp["motion_bucket_id_exps"]
            motion_bucket_id_heads = model_kwargs_tmp["motion_bucket_id_heads"]
            pixel_value_ref = model_kwargs_tmp["pixel_value_ref"]
            pixel_value_ref_llava = model_kwargs_tmp["pixel_value_ref_llava"]

            torch.cuda.synchronize()
            preprocess_time = (datetime.now() - preprocess_start).total_seconds()
            print("=" * 100)
            print("preprocess time :", preprocess_time)
            print("=" * 100)
        else:
            # Worker ranks: placeholders that dist.broadcast_object_list
            # fills in with the real values from rank 0 below.
            text_prompt = None
            audio_path = None
            image_path = None
            fps = None
            audio_prompts = None
            audio_len = None
            motion_bucket_id_exps = None
            motion_bucket_id_heads = None
            pixel_value_ref = None
            pixel_value_ref_llava = None
    except Exception:
        traceback.print_exc()
        if rank == 0:
            return {
                "errCode": -1,  # failed while preparing the inputs
                "content": [{"buffer": None}],
                "info": "failed to preprocess",
            }
    try:
        # Every rank must reach this call: rank 0 supplies the real values and
        # the other ranks receive them in-place via the object broadcast.
        broadcast_params = [
            text_prompt,
            audio_path,
            image_path,
            fps,
            audio_prompts,
            audio_len,
            motion_bucket_id_exps,
            motion_bucket_id_heads,
            pixel_value_ref,
            pixel_value_ref_llava,
        ]
        dist.broadcast_object_list(broadcast_params, src=0)
        outputs = generate_image_parallel(*broadcast_params)
        if rank == 0:
            # Trim the clip to the audio length, then convert the sampled
            # (C, T, H, W) tensor in [0, 1] into a (T, H, W, C) uint8 video.
            samples = outputs["samples"]
            sample = samples[0].unsqueeze(0)
            sample = sample[:, :, :audio_len[0]]
            video = sample[0].permute(1, 2, 3, 0).clamp(0, 1).numpy()
            video = (video * 255.).astype(np.uint8)
            output_dict = {
                "err_code": 0,
                "err_msg": "succeed",
                "video": video,
                "audio": input_dict.get("audio_path", None),
                "save_fps": save_fps,
            }
            return process_output_dict(output_dict)
    except Exception:
        traceback.print_exc()
        if rank == 0:
            return {
                "errCode": -1,  # failed to generate video
                "content": [{"buffer": None}],
                "info": "failed to generate video",
            }
    # Worker ranks return nothing.
    return None
def generate_image_parallel(text_prompt,
                            audio_path,
                            image_path,
                            fps,
                            audio_prompts,
                            audio_len,
                            motion_bucket_id_exps,
                            motion_bucket_id_heads,
                            pixel_value_ref,
                            pixel_value_ref_llava,
                            ):
    """Build the model batch and run sampling; called on every rank with the
    same (broadcast) arguments."""
    batch = {
        "text_prompt": text_prompt,
        "audio_path": audio_path,
        "image_path": image_path,
        "fps": fps,
        "audio_prompts": audio_prompts,
        "audio_len": audio_len,
        "motion_bucket_id_exps": motion_bucket_id_exps,
        "motion_bucket_id_heads": motion_bucket_id_heads,
        "pixel_value_ref": pixel_value_ref,
        "pixel_value_ref_llava": pixel_value_ref_llava,
    }
    samples = hunyuan_sampler.predict(args, batch, wav2vec, feature_extractor, align_instance)
    return samples
def worker_loop():
    """Non-zero ranks: loop forever, blocking on the broadcast inside
    predict_wrap() until rank 0 dispatches the next request."""
    while True:
        predict_wrap()
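
# Example launch (a sketch; the script name and every flag except --ckpt are
# assumptions and may differ from the project's actual CLI):
#
#   MODEL_BASE=/path/to/weights torchrun --nproc_per_node=8 \
#       web_server_audio.py --ckpt /path/to/checkpoint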
if __name__ == "__main__":
    audio_args = parse_args()
    initialize_distributed(audio_args.seed)
    hunyuan_sampler = HunyuanVideoSampler.from_pretrained(
        audio_args.ckpt, args=audio_args)
    args = hunyuan_sampler.args

    rank = local_rank = 0
    device = torch.device("cuda")
    if nccl_info.sp_size > 1:
        device = torch.device(f"cuda:{torch.distributed.get_rank()}")
        rank = local_rank = torch.distributed.get_rank()

    # Whisper encoder used to extract audio features (inference only).
    feature_extractor = AutoFeatureExtractor.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/")
    wav2vec = WhisperModel.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/").to(device=device, dtype=torch.float32)
    wav2vec.requires_grad_(False)

    # Face detection / alignment model for the reference image.
    BASE_DIR = f'{MODEL_OUTPUT_PATH}/ckpts/det_align/'
    det_path = os.path.join(BASE_DIR, 'detface.pt')
    align_instance = AlignImage("cuda", det_path=det_path)

    if rank == 0:
        uvicorn.run(app, host="0.0.0.0", port=80)
    else:
        worker_loop()
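
# Expected layout under $MODEL_BASE, inferred from the paths used above:
#   ckpts/whisper-tiny/          Whisper feature extractor + encoder weights
#   ckpts/det_align/detface.pt   face detection / alignment checkpoint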