Spaces:
Build error
Build error
import os | |
import io | |
import math | |
import uuid | |
import base64 | |
import imageio | |
import torch | |
import torchvision | |
from PIL import Image | |
import numpy as np | |
from copy import deepcopy | |
from einops import rearrange | |
import torchvision.transforms as transforms | |
from torchvision.transforms import ToPILImage | |
from hymm_sp.data_kits.audio_dataset import get_audio_feature | |
from hymm_sp.data_kits.ffmpeg_utils import save_video | |
# Scratch directory used by the helpers in this module for decoded inputs and
# generated media; created at import time when it is not already present.
TEMP_DIR = "./temp"
if not os.path.exists(TEMP_DIR):
    os.makedirs(TEMP_DIR, exist_ok=True)
def data_preprocess_server(args, image_path, audio_path, prompts, feature_extractor):
    """Build a single-sample model input batch from a reference image and an audio file.

    Args:
        args: Object exposing ``image_size``; used as the target length of the
            shorter image side before rounding to a multiple of 64.
        image_path: Path of the reference image (opened with PIL, converted to RGB).
        audio_path: Path of the audio file; forwarded to ``get_audio_feature``.
        prompts: Optional extra prompt text. It is appended to a fixed quality
            prefix; ``None`` selects the prefix alone.
        feature_extractor: Forwarded unchanged to ``get_audio_feature``.

    Returns:
        dict with the keys ``text_prompt``, ``audio_path``, ``image_path``,
        ``fps``, ``audio_prompts``, ``audio_len``, ``motion_bucket_id_exps``,
        ``motion_bucket_id_heads``, ``pixel_value_ref`` and
        ``pixel_value_ref_llava``. Tensor entries carry a leading batch axis of
        size 1; ``fps``, ``audio_prompts`` and both pixel tensors are float16.
    """
    # Preprocessing for the 336x336 "llava" view of the reference image.
    # NOTE(review): the mean/std values look like the CLIP normalization constants — confirm.
    llava_transform = transforms.Compose(
        [
            transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )

    # Build the text prompt: fixed quality prefix, optionally followed by the caller's text.
    if prompts is None:
        prompts = "Authentic, Realistic, Natural, High-quality, Lens-Fixed."
    else:
        prompts = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + prompts

    fps = 25
    img_size = args.image_size

    # convert() returns a fully loaded copy, so the underlying file can be
    # closed right away instead of lingering until garbage collection.
    with Image.open(image_path) as opened_image:
        ref_image = opened_image.convert('RGB')

    # Resize reference image: the short side becomes ~img_size and both sides
    # are rounded to multiples of 64.
    w, h = ref_image.size
    scale = img_size / min(w, h)
    new_w = round(w * scale / 64) * 64
    new_h = round(h * scale / 64) * 64

    if img_size == 704:
        # Cap the total pixel count at 704 * 1216 by rescaling on area instead.
        img_size_long = 1216
        if new_w * new_h > img_size * img_size_long:
            scale = math.sqrt(img_size * img_size_long / w / h)
            new_w = round(w * scale / 64) * 64
            new_h = round(h * scale / 64) * 64

    ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)

    ref_image = np.array(ref_image)
    ref_image = torch.from_numpy(ref_image)  # (h, w, c), uint8

    audio_input, audio_len = get_audio_feature(feature_extractor, audio_path)
    audio_prompts = audio_input[0]

    # Fixed conditioning values, four entries each.
    motion_bucket_id_heads = np.array([25] * 4)
    motion_bucket_id_exps = np.array([30] * 4)
    motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads)
    motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps)
    fps = torch.from_numpy(np.array(fps))

    to_pil = ToPILImage()
    pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w")  # (b c h w)

    # Round-trip each frame through PIL so the torchvision transform can be applied.
    pixel_value_ref_llava = [llava_transform(to_pil(image)) for image in pixel_value_ref]
    pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)

    batch = {
        "text_prompt": [prompts],
        "audio_path": [audio_path],
        "image_path": [image_path],
        "fps": fps.unsqueeze(0).to(dtype=torch.float16),
        "audio_prompts": audio_prompts.unsqueeze(0).to(dtype=torch.float16),
        "audio_len": [audio_len],
        "motion_bucket_id_exps": motion_bucket_id_exps.unsqueeze(0),
        "motion_bucket_id_heads": motion_bucket_id_heads.unsqueeze(0),
        "pixel_value_ref": pixel_value_ref.unsqueeze(0).to(dtype=torch.float16),
        "pixel_value_ref_llava": pixel_value_ref_llava.unsqueeze(0).to(dtype=torch.float16)
    }

    return batch
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
    """Tile a batch of videos into one grid video and write it to ``path``.

    Args:
        videos: Tensor laid out as ``(b, c, t, h, w)``.
        path: Output file path handed to ``imageio.mimsave``. A bare file name
            (no directory component) is written to the current directory.
        rescale: When true, values are first mapped from ``[-1, 1]`` to ``[0, 1]``.
        n_rows: Forwarded to ``torchvision.utils.make_grid`` as ``nrow``.
        fps: Frame rate forwarded to ``imageio.mimsave``.
        quality: Quality setting forwarded to ``imageio.mimsave``.
    """
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        # (c, h, w) -> (h, w, c)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = torch.clamp(x, 0, 1)
        # detach()/cpu() are no-ops for plain CPU tensors but let GPU or
        # grad-tracking inputs be converted without raising.
        x = (x * 255).detach().cpu().numpy().astype(np.uint8)
        outputs.append(x)

    # dirname is "" for a bare file name, and os.makedirs("") raises.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps, quality=quality)
def encode_image_to_base64(image_path):
    """Read the file at ``image_path`` and return its bytes as a Base64 string.

    Any failure is reported on stdout and turned into a ``None`` return value.
    """
    try:
        with open(image_path, 'rb') as fh:
            b64_bytes = base64.b64encode(fh.read())
        b64_text = b64_bytes.decode('utf-8')
        print(f"Image file '{image_path}' has been successfully encoded to Base64.")
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
    return b64_text
def encode_video_to_base64(video_path):
    """Read the file at ``video_path`` and return its bytes as a Base64 string.

    Any failure is reported on stdout and turned into a ``None`` return value.
    """
    try:
        with open(video_path, 'rb') as fh:
            b64_bytes = base64.b64encode(fh.read())
        b64_text = b64_bytes.decode('utf-8')
        print(f"Video file '{video_path}' has been successfully encoded to Base64.")
    except Exception as e:
        print(f"Error encoding video: {e}")
        return None
    return b64_text
def encode_wav_to_base64(wav_path):
    """Read the file at ``wav_path`` and return its bytes as a Base64 string.

    Any failure is reported on stdout and turned into a ``None`` return value.
    """
    try:
        with open(wav_path, 'rb') as fh:
            b64_bytes = base64.b64encode(fh.read())
        b64_text = b64_bytes.decode('utf-8')
        print(f"Audio file '{wav_path}' has been successfully encoded to Base64.")
    except Exception as e:
        print(f"Error encoding audio: {e}")
        return None
    return b64_text
def encode_pkl_to_base64(pkl_path):
    """Read the file at ``pkl_path`` and return its bytes as a Base64 string.

    The file content is treated as opaque bytes; it is never unpickled here.
    Any failure is reported on stdout and turned into a ``None`` return value.
    """
    try:
        with open(pkl_path, 'rb') as fh:
            b64_bytes = base64.b64encode(fh.read())
        b64_text = b64_bytes.decode('utf-8')
        print(f"Pickle file '{pkl_path}' has been successfully encoded to Base64.")
    except Exception as e:
        print(f"Error encoding pickle: {e}")
        return None
    return b64_text
def decode_base64_to_image(base64_buffer_str):
    """Decode a Base64 string into an image array.

    Args:
        base64_buffer_str: Base64 text (or bytes) holding an encoded image file.

    Returns:
        ``np.ndarray`` built from the decoded PIL image, or ``None`` when the
        input cannot be decoded (the error is printed, not raised).
    """
    try:
        image_data = base64.b64decode(base64_buffer_str)
        image = Image.open(io.BytesIO(image_data))
        image_array = np.array(image)
        print("Image Base64 string has been successfully decoded to image.")
        return image_array
    except Exception as e:
        # Best-effort API: callers receive None instead of an exception.
        print(f"Error decoding image: {e}")
        return None
def decode_base64_to_video(base64_buffer_str):
    """Decode a Base64 string into a list of video frames.

    Args:
        base64_buffer_str: Base64 text (or bytes) holding an encoded video file.

    Returns:
        List of the frames yielded by the imageio ``'ffmpeg'`` reader, or
        ``None`` when decoding fails (the error is printed, not raised).
    """
    try:
        video_data = base64.b64decode(base64_buffer_str)
        video_reader = imageio.get_reader(io.BytesIO(video_data), 'ffmpeg')
        try:
            # NOTE(review): kept as a comprehension rather than list(reader);
            # list() consults len(), which some readers may report as huge — confirm.
            video_frames = [frame for frame in video_reader]
        finally:
            # Release the reader (and whatever it holds) even when iteration fails.
            video_reader.close()
        return video_frames
    except Exception as e:
        # Best-effort API: callers receive None instead of an exception.
        print(f"Error decoding video: {e}")
        return None
def save_video_base64_to_local(video_path=None, base64_buffer=None, output_video_path=None):
    """Write a video, given as a file path or a Base64 buffer, to local disk.

    Exactly one of ``video_path`` and ``base64_buffer`` must be supplied.
    The decoded bytes go to ``output_video_path`` when given, otherwise to a
    fresh ``<TEMP_DIR>/<uuid>.mp4``. Returns the written path, or ``None``
    when the arguments are invalid or the source could not be encoded.
    """
    path_given = video_path is not None
    buffer_given = base64_buffer is not None
    if path_given == buffer_given:
        # Both or neither were supplied.
        print("Please pass either 'video_path' or 'base64_buffer'")
        return None

    if path_given:
        encoded = encode_video_to_base64(video_path)
    else:
        encoded = deepcopy(base64_buffer)
    if encoded is None:
        return None

    raw_bytes = base64.b64decode(encoded)
    destination = output_video_path
    if destination is None:
        destination = f'{TEMP_DIR}/{str(uuid.uuid4())}.mp4'
    with open(destination, 'wb') as sink:
        sink.write(raw_bytes)
    return destination
def save_audio_base64_to_local(audio_path=None, base64_buffer=None):
    """Write audio, given as a file path or a Base64 buffer, to a temp ``.wav`` file.

    Exactly one of ``audio_path`` and ``base64_buffer`` must be supplied.
    Returns the path of the new ``<TEMP_DIR>/<uuid>.wav`` file, or ``None``
    when the arguments are invalid or the source could not be encoded.
    """
    path_given = audio_path is not None
    buffer_given = base64_buffer is not None
    if path_given == buffer_given:
        # Both or neither were supplied.
        print("Please pass either 'audio_path' or 'base64_buffer'")
        return None

    if path_given:
        encoded = encode_wav_to_base64(audio_path)
    else:
        encoded = deepcopy(base64_buffer)
    if encoded is None:
        return None

    raw_bytes = base64.b64decode(encoded)
    destination = f'{TEMP_DIR}/{str(uuid.uuid4())}.wav'
    with open(destination, 'wb') as sink:
        sink.write(raw_bytes)
    return destination
def save_pkl_base64_to_local(pkl_path=None, base64_buffer=None):
    """Write a pickle payload, given as a file path or a Base64 buffer, to a temp ``.pkl`` file.

    Exactly one of ``pkl_path`` and ``base64_buffer`` must be supplied. The
    bytes are only copied, never unpickled. Returns the path of the new
    ``<TEMP_DIR>/<uuid>.pkl`` file, or ``None`` when the arguments are invalid
    or the source could not be encoded.
    """
    path_given = pkl_path is not None
    buffer_given = base64_buffer is not None
    if path_given == buffer_given:
        # Both or neither were supplied.
        print("Please pass either 'pkl_path' or 'base64_buffer'")
        return None

    if path_given:
        encoded = encode_pkl_to_base64(pkl_path)
    else:
        encoded = deepcopy(base64_buffer)
    if encoded is None:
        return None

    raw_bytes = base64.b64decode(encoded)
    destination = f'{TEMP_DIR}/{str(uuid.uuid4())}.pkl'
    with open(destination, 'wb') as sink:
        sink.write(raw_bytes)
    return destination
def remove_temp_fles(input_dict):
    """Delete the files referenced by the ``*_path*`` entries of ``input_dict``.

    An entry is acted on only when its key contains ``"_path"``, its value is
    not ``None`` and the value names an existing filesystem path. Each removal
    is reported on stdout.
    """
    for entry_name, location in input_dict.items():
        if "_path" not in entry_name:
            continue
        if location is None or not os.path.exists(location):
            continue
        os.remove(location)
        print(f"Remove temporary {entry_name} from {location}")
def process_output_dict(output_dict):
    """Render the generated video to disk, optionally mux in audio, and Base64-encode it.

    Args:
        output_dict: Mapping that must contain ``"video"``, ``"audio"``,
            ``"err_code"`` and ``"err_msg"``; ``"save_fps"`` is optional
            (default 25). ``"audio"`` is a file path or ``None``.

    Returns:
        dict with ``errCode``, ``content`` (a one-element list holding the
        Base64 video under ``"buffer"``) and ``info``. ``buffer`` is ``None``
        when the final video file could not be read.
    """
    # Local import: keeps this block self-contained without touching the
    # module-level import list.
    import subprocess

    uuid_string = str(uuid.uuid4())
    temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4'
    save_video(output_dict["video"], temp_video_path, fps=output_dict.get("save_fps", 25))

    # Add audio
    audio_path = output_dict["audio"]
    if audio_path is not None and os.path.exists(audio_path):
        output_path = temp_video_path
        save_path = temp_video_path.replace(".mp4", "_audio.mp4")
        print('='*100)
        print(f"output_path = {output_path}\n audio_path = {audio_path}\n save_path = {save_path}")
        # Argument list instead of a shell string: paths containing quotes or
        # shell metacharacters can neither break the command nor inject one.
        # The argument order matches the previous shell command.
        mux_command = [
            "ffmpeg", "-i", output_path, "-i", audio_path,
            "-shortest", save_path, "-y", "-loglevel", "quiet",
        ]
        try:
            subprocess.run(mux_command, check=False)
        except OSError as e:
            # e.g. ffmpeg missing; the shell version also continued after a failure.
            print(f"Error running ffmpeg: {e}")
        # The silent intermediate is discarded whether or not muxing succeeded,
        # as the previous "; rm" did.
        if os.path.exists(output_path):
            os.remove(output_path)
    else:
        save_path = temp_video_path

    video_base64_buffer = encode_video_to_base64(save_path)

    encoded_output_dict = {
        "errCode": output_dict["err_code"],
        "content": [
            {
                "buffer": video_base64_buffer
            },
        ],
        "info": output_dict["err_msg"],
    }
    return encoded_output_dict
def save_image_base64_to_local(image_path=None, base64_buffer=None):
    """Write an image, given as a file path or a Base64 buffer, to a temp ``.png`` file.

    Exactly one of ``image_path`` and ``base64_buffer`` must be supplied.
    The bytes are written unchanged under a ``.png`` name. Returns the path of
    the new ``<TEMP_DIR>/<uuid>.png`` file, or ``None`` when the arguments are
    invalid or the source could not be encoded.
    """
    path_given = image_path is not None
    buffer_given = base64_buffer is not None
    if path_given == buffer_given:
        # Both or neither were supplied.
        print("Please pass either 'image_path' or 'base64_buffer'")
        return None

    # Obtain the Base64 form of the image.
    if path_given:
        encoded = encode_image_to_base64(image_path)
    else:
        encoded = deepcopy(base64_buffer)
    if encoded is None:
        return None

    # Decode and persist to local disk.
    raw_bytes = base64.b64decode(encoded)
    destination = f'{TEMP_DIR}/{str(uuid.uuid4())}.png'
    with open(destination, 'wb') as sink:
        sink.write(raw_bytes)
    return destination
def process_input_dict(input_dict):
    """Turn a request payload into local file paths and plain values.

    Reads ``save_fps`` (default 25), ``image_buffer``, ``audio_buffer`` and
    ``text`` from ``input_dict``. Base64 buffers that are present are written
    to temp files through the ``save_*_base64_to_local`` helpers.

    Returns:
        dict with the keys ``save_fps``, ``image_path``, ``audio_path`` and
        ``prompt``; the path entries are ``None`` when no buffer was supplied.
    """
    save_fps = input_dict.get("save_fps", 25)

    image_buffer = input_dict.get("image_buffer", None)
    image_file = None
    if image_buffer is not None:
        image_file = save_image_base64_to_local(image_path=None, base64_buffer=image_buffer)

    audio_buffer = input_dict.get("audio_buffer", None)
    audio_file = None
    if audio_buffer is not None:
        audio_file = save_audio_base64_to_local(audio_path=None, base64_buffer=audio_buffer)

    return {
        "save_fps": save_fps,
        "image_path": image_file,
        "audio_path": audio_file,
        "prompt": input_dict.get("text", None),
    }