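"""Gradio demo: generate and render SMPL-X face/body motion from an uploaded audio file.

Pipeline: load the face and body generators plus an SMPL-X body model, run
inference on the audio, render with RenderTool, and return the path of the
result written under visualise/video/<run-name>/.
"""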
import os
import sys

# Set environment variables before the rendering stack is imported:
# PyOpenGL chooses its platform backend at import time.
os.environ['PYOPENGL_PLATFORM'] = 'egl'  # offscreen (headless) OpenGL rendering
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # restrict inference to the first GPU
sys.path.append(os.getcwd())  # make local packages (nets, trainer, visualise, ...) importable

import gradio as gr
import torch
import numpy as np
import smplx as smpl  # provides smpl.create(), used to build the SMPL-X model below
from transformers import Wav2Vec2Processor
from visualise.rendering import RenderTool
from data_utils import torch_data
from trainer.options import parse_args
from trainer.config import load_JsonConfig
from nets import init_model, infer
# Load the models and configuration
def load_models(config_file, face_model_name, face_model_path, body_model_name, body_model_path):
    args = parse_args()
    config = load_JsonConfig(config_file)

    # Initialize models
    generator_face = init_model(face_model_name, face_model_path, args, config)
    generator_body = init_model(body_model_name, body_model_path, args, config)

    # Initialize SMPL-X model
    smplx_model_params = {
        'model_path': './visualise/',
        'model_type': 'smplx',
        'create_global_orient': True,
        'create_body_pose': True,
        'create_betas': True,
        'num_betas': 300,
        'create_left_hand_pose': True,
        'create_right_hand_pose': True,
        'use_pca': False,
        'flat_hand_mean': False,
        'create_expression': True,
        'num_expression_coeffs': 100,
        'num_pca_comps': 12,
        'create_jaw_pose': True,
        'create_leye_pose': True,
        'create_reye_pose': True,
        'create_transl': False,
        'dtype': torch.float64,
    }
    smplx_model = smpl.create(**smplx_model_params).to('cuda')

    return generator_face, generator_body, smplx_model, config
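# NOTE: run_inference below reloads the generators and the SMPL-X model on
# every request. If startup cost matters, the loaded models could instead be
# cached at module scope and reused across calls.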
# Inference function
def run_inference(audio_file):
    # Load models (face and body share the same LS3DCG checkpoint)
    generator_face, generator_body, smplx_model, config = load_models(
        './config/LS3DCG.json',
        's2g_LS3DCG',
        'experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth',
        's2g_LS3DCG',
        'experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth',
    )

    # Initialize rendering tool
    rendertool = RenderTool('visualise/video/' + config.Log.name)

    # Inference: generate motion from the audio and render it
    infer(generator_body, generator_face, smplx_model, rendertool, config, audio_file)

    # Output path for this clip, as written under the run's video directory
    audio_name = os.path.splitext(os.path.basename(audio_file))[0]
    output_video_path = f'visualise/video/{config.Log.name}/{audio_name}.npy'
    return output_video_path
# Gradio interface (uses the top-level components; gr.inputs/gr.outputs are deprecated)
iface = gr.Interface(
    fn=run_inference,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=gr.Textbox(label="Output Video Path"),
    title="Audio to 3D Model Renderer",
    description="Upload an audio file to generate a 3D model rendering.",
)
if __name__ == "__main__":
    iface.launch()
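# To run locally: `python app.py`, then open the printed URL (Gradio serves on
# http://127.0.0.1:7860 by default). Passing share=True to launch() creates a
# temporary public link.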