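"""Gradio demo for SoundImage: audio-conditioned video generation.

The app downloads the LatentSync checkpoints from the Hugging Face Hub, then
lets the user pick a checkpoint and mask, upload a video and an audio track,
and run inference to produce a new, audio-conditioned output video.
"""
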
import argparse
import os
import subprocess
from datetime import datetime
from pathlib import Path

import gradio as gr
from omegaconf import OmegaConf

from scripts.inference import main

CONFIG_PATH = Path("configs/unet/second_stage.yaml")
CHECKPOINT_PATH = Path("checkpoints/latentsync_unet.pt")

# Download the model checkpoints from the Hugging Face Hub at startup
# (requires the HF_TOKEN environment variable to be set).
subprocess.run(
    ["huggingface-cli", "download", "Hyathi/LatentSync", "--local-dir", "checkpoints",
     "--exclude", "*.git*", "README.md", "--token", os.environ["HF_TOKEN"]],
    check=True,  # fail fast if the download doesn't succeed
)
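
# Equivalent download via the huggingface_hub Python API, if you'd rather not
# shell out to the CLI (a sketch, assuming the same repo layout and token):
#
#     from huggingface_hub import snapshot_download
#
#     snapshot_download(
#         repo_id="Hyathi/LatentSync",
#         local_dir="checkpoints",
#         ignore_patterns=["*.git*", "README.md"],
#         token=os.environ["HF_TOKEN"],
#     )
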
def process_video(
    video_path,
    audio_path,
    guidance_scale,
    inference_steps,
    seed,
    checkpoint_file,
    mask_file,
):
    # Create the temp directory if it doesn't exist
    output_dir = Path("./temp")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Use the selected checkpoint, or fall back to the default
    checkpoint_path = Path("checkpoints/unetFiles") / checkpoint_file if checkpoint_file else CHECKPOINT_PATH

    # Resolve the selected mask, if any
    mask_path = Path("masks") / mask_file if mask_file else None

    # Normalize the input paths to absolute, POSIX-style paths
    video_file_path = Path(video_path)
    video_path = video_file_path.absolute().as_posix()
    audio_path = Path(audio_path).absolute().as_posix()

    # Timestamp the output filename so repeated runs don't overwrite each other
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = str(output_dir / f"{video_file_path.stem}_{current_time}.mp4")

    config = OmegaConf.load(CONFIG_PATH)
    config["run"].update(
        {
            "guidance_scale": guidance_scale,
            "inference_steps": inference_steps,
        }
    )

    # Build the argument namespace expected by the inference script
    args = create_args(video_path, audio_path, output_path, guidance_scale, seed, checkpoint_path, mask_path)

    try:
        main(config=config, args=args)
        print("Processing completed successfully.")
        return output_path
    except Exception as e:
        print(f"Error during processing: {str(e)}")
        raise gr.Error(f"Error during processing: {str(e)}")

def create_args(
    video_path: str,
    audio_path: str,
    output_path: str,
    guidance_scale: float,
    seed: int,
    checkpoint_path: Path,
    mask_path: Path | None,
) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_ckpt_path", type=str, required=True)
    parser.add_argument("--video_path", type=str, required=True)
    parser.add_argument("--audio_path", type=str, required=True)
    parser.add_argument("--video_out_path", type=str, required=True)
    parser.add_argument("--guidance_scale", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=1247)
    parser.add_argument("--mask_path", type=str, required=False)
    return parser.parse_args(
        [
            "--inference_ckpt_path",
            checkpoint_path.absolute().as_posix(),
            "--video_path",
            video_path,
            "--audio_path",
            audio_path,
            "--video_out_path",
            output_path,
            "--guidance_scale",
            str(guidance_scale),
            "--seed",
            str(seed),
            "--mask_path",
            mask_path.absolute().as_posix() if mask_path else "",
        ]
    )
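
# An equivalent sketch that skips the parser round-trip and builds the
# namespace directly (assuming main() only reads the fields defined above):
#
#     args = argparse.Namespace(
#         inference_ckpt_path=checkpoint_path.absolute().as_posix(),
#         video_path=video_path,
#         audio_path=audio_path,
#         video_out_path=output_path,
#         guidance_scale=float(guidance_scale),
#         seed=int(seed),
#         mask_path=mask_path.absolute().as_posix() if mask_path else "",
#     )
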
def get_checkpoint_files():
    """Return the .pt checkpoint filenames available for the dropdown."""
    unet_files_dir = Path("checkpoints/unetFiles")
    if not unet_files_dir.exists():
        return []
    return [f.name for f in unet_files_dir.glob("*.pt")]


def get_mask_files():
    """Return the mask image filenames (assumed to be PNGs) for the dropdown."""
    masks_dir = Path("masks")
    if not masks_dir.exists():
        return []
    return [f.name for f in masks_dir.glob("*.png")]
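
# Note: the dropdown choices below are gathered once at app startup; files
# added to checkpoints/unetFiles/ or masks/ afterwards won't appear until the
# app restarts (or until the dropdowns are refreshed via an update event).
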
# Create the Gradio interface
with gr.Blocks(title="SoundImage") as demo:
    gr.Markdown(
        """
        # SoundImage: Audio-Conditioned Video Generation

        Upload a video and an audio file to process with the SoundImage model.
        """
    )
    with gr.Row():
        with gr.Column():
            # Checkpoint and mask selectors
            checkpoint_files = get_checkpoint_files()
            checkpoint_dropdown = gr.Dropdown(
                choices=checkpoint_files,
                label="Select Checkpoint",
                value=checkpoint_files[0] if checkpoint_files else None,
            )
            mask_files = get_mask_files()
            mask_dropdown = gr.Dropdown(
                choices=mask_files,
                label="Select Mask",
                value=mask_files[0] if mask_files else None,
            )
            video_input = gr.Video(label="Input Video")
            audio_input = gr.Audio(label="Input Audio", type="filepath")
            with gr.Row():
                guidance_scale = gr.Slider(
                    minimum=0.1,
                    maximum=3.0,
                    value=1.0,
                    step=0.1,
                    label="Guidance Scale",
                )
                inference_steps = gr.Slider(
                    minimum=1, maximum=50, value=20, step=1, label="Inference Steps"
                )
            with gr.Row():
                seed = gr.Number(value=1247, label="Random Seed", precision=0)
            process_btn = gr.Button("Process Video")
        with gr.Column():
            video_output = gr.Video(label="Output Video")

    # gr.Examples(
    #     examples=[
    #         ["assets/demo1_video.mp4", "assets/demo1_audio.wav"],
    #         ["assets/demo2_video.mp4", "assets/demo2_audio.wav"],
    #         ["assets/demo3_video.mp4", "assets/demo3_audio.wav"],
    #     ],
    #     inputs=[video_input, audio_input],
    # )

    process_btn.click(
        fn=process_video,
        inputs=[
            video_input,
            audio_input,
            guidance_scale,
            inference_steps,
            seed,
            checkpoint_dropdown,
            mask_dropdown,
        ],
        outputs=video_output,
    )

if __name__ == "__main__":
    demo.launch(inbrowser=True, share=True)
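
# Run locally with `python app.py`; `share=True` requests a public gradio.live
# link for local runs (on a Hugging Face Space the app is already served at a
# public URL, so the flag is unnecessary there).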