# Spaces: Sleeping
import os
import sys

# Set environment variables
# This setup must run BEFORE the third-party / project imports below:
#   * PyOpenGL picks its platform backend from PYOPENGL_PLATFORM when it is
#     first imported, so setting it after the rendering stack has been loaded
#     is too late (presumably `visualise.rendering` pulls in an OpenGL-based
#     renderer -- confirm).
#   * CUDA_VISIBLE_DEVICES is only honoured if set before CUDA is initialised.
os.environ['PYOPENGL_PLATFORM'] = 'egl'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Extend the import path before the project-local packages are imported;
# appending it after those imports (as before) could not help resolve them.
sys.path.append(os.getcwd())

import gradio as gr  # noqa: E402
import torch  # noqa: E402
import numpy as np  # noqa: E402
from transformers import Wav2Vec2Processor  # noqa: E402
from visualise.rendering import RenderTool  # noqa: E402
from data_utils import torch_data  # noqa: E402
from trainer.options import parse_args  # noqa: E402
from trainer.config import load_JsonConfig  # noqa: E402
from nets import init_model, infer  # noqa: E402  # Ensure these functions are properly defined
# Load the model and configuration
def load_models(config_file, face_model_name, face_model_path, body_model_name, body_model_path):
    """Build the face generator, body generator and SMPL-X model.

    Args:
        config_file: Path handed to ``load_JsonConfig``.
        face_model_name: Model name handed to ``init_model`` for the face generator.
        face_model_path: Checkpoint path handed to ``init_model`` for the face generator.
        body_model_name: Model name handed to ``init_model`` for the body generator.
        body_model_path: Checkpoint path handed to ``init_model`` for the body generator.

    Returns:
        A tuple ``(generator_face, generator_body, smplx_model, config)``.
    """
    # Bug fix: the original called `smpl.create(...)` but no import visible in
    # this file binds the name `smpl`, so the call raised NameError. The
    # keyword set below (model_type='smplx', ...) matches the `smplx` package's
    # `create()` factory, so bind it here under the alias the code expects.
    import smplx as smpl

    args = parse_args()
    config = load_JsonConfig(config_file)

    # Initialize models
    generator_face = init_model(face_model_name, face_model_path, args, config)
    generator_body = init_model(body_model_name, body_model_path, args, config)

    # Initialize SMPL-X model
    smplx_model_params = {
        'model_path': './visualise/',
        'model_type': 'smplx',
        'create_global_orient': True,
        'create_body_pose': True,
        'create_betas': True,
        'num_betas': 300,
        'create_left_hand_pose': True,
        'create_right_hand_pose': True,
        'use_pca': False,
        'flat_hand_mean': False,
        'create_expression': True,
        'num_expression_coeffs': 100,
        'num_pca_comps': 12,
        'create_jaw_pose': True,
        'create_leye_pose': True,
        'create_reye_pose': True,
        'create_transl': False,
        'dtype': torch.float64,
    }
    # The body model is placed on the GPU; a CUDA device is required here.
    smplx_model = smpl.create(**smplx_model_params).to('cuda')

    return generator_face, generator_body, smplx_model, config
# Inference function
def run_inference(audio_file):
    """Run the audio-driven generation pipeline for one uploaded file.

    Args:
        audio_file: Filesystem path of the uploaded audio (the interface is
            configured with ``type="filepath"``).

    Returns:
        A string of the form
        ``visualise/video/<config.Log.name>/<audio stem>.npy``.
        NOTE(review): the UI labels this "Output Video Path" but the suffix is
        ``.npy``; whether ``infer`` writes a file at exactly this location is
        not visible from here -- confirm against ``nets.infer``.

    Raises:
        ValueError: If no audio path was supplied (e.g. the form was submitted
            without an upload).
    """
    # Fail fast: without this guard a missing upload only surfaced after all
    # models had been loaded, as an opaque error deep inside the pipeline.
    if not audio_file:
        raise ValueError("No audio file was provided.")

    # Load models (the same checkpoint is used for both face and body)
    checkpoint = 'experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth'
    generator_face, generator_body, smplx_model, config = load_models(
        './config/LS3DCG.json',
        's2g_LS3DCG',
        checkpoint,
        's2g_LS3DCG',
        checkpoint,
    )

    # Initialize rendering tool
    rendertool = RenderTool('visualise/video/' + config.Log.name)

    # Inference
    infer(generator_body, generator_face, smplx_model, rendertool, config, audio_file)

    # Provide output: last path component, truncated at its first "."
    audio_stem = audio_file.split("/")[-1].split(".")[0]
    output_video_path = f'visualise/video/{config.Log.name}/{audio_stem}.npy'
    return output_video_path
# Gradio interface
# NOTE(review): `gr.inputs` / `gr.outputs` are the legacy Gradio component
# namespaces -- confirm the installed Gradio version still provides them.
_audio_input = gr.inputs.Audio(source="upload", type="filepath", label="Upload Audio File")
_path_output = gr.outputs.Textbox(label="Output Video Path")

# Single-function UI: the uploaded file's path goes to run_inference and the
# string it returns is shown in the textbox.
iface = gr.Interface(
    fn=run_inference,
    inputs=_audio_input,
    outputs=_path_output,
    title="Audio to 3D Model Renderer",
    description="Upload an audio file to generate a 3D model rendering.",
)

# Only start the web server when executed as a script.
if __name__ == "__main__":
    iface.launch()