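"""Gradio demo for ParaLip video dubbing.

Loads a FastLip checkpoint and exposes a simple web UI: upload a video,
choose a target language, and receive the re-rendered lip sequence as an MP4.
"""
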
import gradio as gr
import torch
import yaml
import os
from modules.fslip import FastLip
import numpy as np
import cv2
from moviepy.editor import VideoFileClip
import tempfile

# Load configuration
def load_config():
    with open('configs/lipgen/grid/lipgen_grid.yaml', 'r') as f:
        config = yaml.safe_load(f)
    return config

# Initialize model
def init_model():
    config = load_config()
    model = FastLip(
        arch=config['arch'],
        dictionary=None,  # We'll need to implement a simple dictionary
        out_dims=None
    )
    # Load checkpoint
    checkpoint = torch.load('checkpoints/lipgen_grid.pt', map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model
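
# Reloading the checkpoint from disk on every request is slow; below is a
# minimal caching sketch so the model is initialized only once per process.
# The _MODEL/get_model names are additions here, not part of the original repo.
_MODEL = None

def get_model():
    global _MODEL
    if _MODEL is None:
        _MODEL = init_model()
    return _MODEL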

# Process video frames
def process_video(video_path, target_language):
    model = get_model()  # loaded once and cached, see get_model above
    
    # Load video; iter_frames() yields RGB uint8 arrays
    video = VideoFileClip(str(video_path))
    fps = video.fps or 25.0  # fall back to 25 fps if the source reports none
    frames = []
    for frame in video.iter_frames():
        # Resize to the model input size of 80x160 (cv2.resize takes (width, height))
        frame = cv2.resize(frame, (160, 80))
        frames.append(frame)
    video.close()
    
    # Stack into a (T, C, H, W) float tensor scaled to [0, 1]
    frames = torch.FloatTensor(np.array(frames)).permute(0, 3, 1, 2) / 255.0
    
    # Process with model
    with torch.no_grad():
        # TODO: Implement text processing for the target language
        # For now the model runs on the frames without language conditioning
        output = model(frames.unsqueeze(0))
    
    # Convert output to uint8 frames, clamping to [0, 1] first so values
    # outside the valid range don't wrap around in the cast
    output_frames = output['lip_out'].squeeze(0).cpu().numpy()
    output_frames = (np.clip(output_frames, 0.0, 1.0) * 255).astype(np.uint8)
    
    # Save to temporary file
    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, 'output.mp4')
    
    # Create video from frames; output_frames has shape (T, C, H, W)
    height, width = output_frames.shape[2:4]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    for frame in output_frames:
        # CHW -> HWC, then RGB -> BGR because cv2.VideoWriter expects BGR
        frame = cv2.cvtColor(frame.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
        out.write(frame)
    out.release()
    
    return output_path
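
# Example usage (hypothetical local file; returns the path of the rendered
# MP4 inside a fresh temporary directory):
#   output_mp4 = process_video("samples/demo.mp4", "spanish")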

# Create Gradio interface
def create_interface():
    with gr.Blocks(title="ParaLip Video Dubbing") as demo:
        gr.Markdown("""
        # ParaLip Video Dubbing
        Upload a video and select a target language to create a dubbed version.
        """)
        
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video")
                language = gr.Dropdown(
                    choices=["spanish", "french", "german", "italian", "portuguese"],
                    value="spanish",
                    label="Target Language"
                )
                dub_button = gr.Button("Dub Video")
            
            with gr.Column():
                status = gr.Textbox(label="Status")
                video_output = gr.Video(label="Dubbed Video")
        
        def process_video_wrapper(video_file, target_lang):
            if video_file is None:
                return "Please upload a video file", None

            try:
                # gr.Video hands the upload to the callback as a path to a
                # file Gradio has already saved on disk, so it can be passed
                # to process_video directly; calling .read() on it would fail
                output_path = process_video(video_file, target_lang)
                return "Dubbing completed successfully!", output_path

            except Exception as e:
                return f"Error during dubbing: {e}", None
        
        dub_button.click(
            fn=process_video_wrapper,
            inputs=[video_input, language],
            outputs=[status, video_output]
        )
    
    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch()