Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Update DenseAV/denseav/plotting.py
Browse files- DenseAV/denseav/plotting.py +246 -244
    	
        DenseAV/denseav/plotting.py
    CHANGED
    
    | @@ -1,244 +1,246 @@ | |
| 1 | 
            -
            import os
         | 
| 2 | 
            -
            from collections import defaultdict
         | 
| 3 | 
            -
             | 
| 4 | 
            -
            import matplotlib.colors as mcolors
         | 
| 5 | 
            -
            import matplotlib.pyplot as plt
         | 
| 6 | 
            -
            import numpy as np
         | 
| 7 | 
            -
            import scipy.io.wavfile as wavfile
         | 
| 8 | 
            -
            import torch
         | 
| 9 | 
            -
            import torch.nn.functional as F
         | 
| 10 | 
            -
            import torchvision
         | 
| 11 | 
            -
            from moviepy | 
| 12 | 
            -
            from  | 
| 13 | 
            -
            from  | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
                 | 
| 21 | 
            -
             | 
| 22 | 
            -
                 | 
| 23 | 
            -
                -  | 
| 24 | 
            -
                -  | 
| 25 | 
            -
                -  | 
| 26 | 
            -
                 | 
| 27 | 
            -
                 | 
| 28 | 
            -
             | 
| 29 | 
            -
                 | 
| 30 | 
            -
             | 
| 31 | 
            -
                 | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
                    ' | 
| 35 | 
            -
             | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
                         | 
| 41 | 
            -
                         | 
| 42 | 
            -
             | 
| 43 | 
            -
             | 
| 44 | 
            -
                     | 
| 45 | 
            -
             | 
| 46 | 
            -
                     | 
| 47 | 
            -
                     | 
| 48 | 
            -
                     | 
| 49 | 
            -
                     | 
| 50 | 
            -
                     | 
| 51 | 
            -
             | 
| 52 | 
            -
                     | 
| 53 | 
            -
             | 
| 54 | 
            -
             | 
| 55 | 
            -
                         | 
| 56 | 
            -
                         | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
| 61 | 
            -
             | 
| 62 | 
            -
             | 
| 63 | 
            -
             | 
| 64 | 
            -
             | 
| 65 | 
            -
                     | 
| 66 | 
            -
                     | 
| 67 | 
            -
                     | 
| 68 | 
            -
             | 
| 69 | 
            -
             | 
| 70 | 
            -
             | 
| 71 | 
            -
             | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
| 74 | 
            -
             | 
| 75 | 
            -
             | 
| 76 | 
            -
                     | 
| 77 | 
            -
             | 
| 78 | 
            -
                     | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
| 81 | 
            -
             | 
| 82 | 
            -
                         | 
| 83 | 
            -
             | 
| 84 | 
            -
             | 
| 85 | 
            -
                             | 
| 86 | 
            -
             | 
| 87 | 
            -
             | 
| 88 | 
            -
             | 
| 89 | 
            -
                         | 
| 90 | 
            -
             | 
| 91 | 
            -
             | 
| 92 | 
            -
             | 
| 93 | 
            -
                                 | 
| 94 | 
            -
             | 
| 95 | 
            -
             | 
| 96 | 
            -
             | 
| 97 | 
            -
             | 
| 98 | 
            -
             | 
| 99 | 
            -
             | 
| 100 | 
            -
             | 
| 101 | 
            -
             | 
| 102 | 
            -
             | 
| 103 | 
            -
                 | 
| 104 | 
            -
                 | 
| 105 | 
            -
                plasma_with_alpha | 
| 106 | 
            -
                 | 
| 107 | 
            -
             | 
| 108 | 
            -
             | 
| 109 | 
            -
             | 
| 110 | 
            -
             | 
| 111 | 
            -
             | 
| 112 | 
            -
                 | 
| 113 | 
            -
                custom_cmap | 
| 114 | 
            -
                 | 
| 115 | 
            -
                 | 
| 116 | 
            -
                custom_cmap[threshold_index | 
| 117 | 
            -
                 | 
| 118 | 
            -
                 | 
| 119 | 
            -
             | 
| 120 | 
            -
             | 
| 121 | 
            -
             | 
| 122 | 
            -
             | 
| 123 | 
            -
             | 
| 124 | 
            -
                 | 
| 125 | 
            -
                 | 
| 126 | 
            -
                plasma_with_alpha | 
| 127 | 
            -
                 | 
| 128 | 
            -
             | 
| 129 | 
            -
             | 
| 130 | 
            -
             | 
| 131 | 
            -
             | 
| 132 | 
            -
             | 
| 133 | 
            -
                ' | 
| 134 | 
            -
                ' | 
| 135 | 
            -
             | 
| 136 | 
            -
             | 
| 137 | 
            -
             | 
| 138 | 
            -
             | 
| 139 | 
            -
             | 
| 140 | 
            -
                ' | 
| 141 | 
            -
                ' | 
| 142 | 
            -
             | 
| 143 | 
            -
             | 
| 144 | 
            -
             | 
| 145 | 
            -
             | 
| 146 | 
            -
             | 
| 147 | 
            -
             | 
| 148 | 
            -
                 | 
| 149 | 
            -
                 | 
| 150 | 
            -
                sims_all = sims_all | 
| 151 | 
            -
                 | 
| 152 | 
            -
                 | 
| 153 | 
            -
                 | 
| 154 | 
            -
                 | 
| 155 | 
            -
             | 
| 156 | 
            -
             | 
| 157 | 
            -
                     | 
| 158 | 
            -
                     | 
| 159 | 
            -
                     | 
| 160 | 
            -
             | 
| 161 | 
            -
             | 
| 162 | 
            -
             | 
| 163 | 
            -
             | 
| 164 | 
            -
             | 
| 165 | 
            -
                 | 
| 166 | 
            -
             | 
| 167 | 
            -
                 | 
| 168 | 
            -
             | 
| 169 | 
            -
                 | 
| 170 | 
            -
             | 
| 171 | 
            -
                 | 
| 172 | 
            -
             | 
| 173 | 
            -
                 | 
| 174 | 
            -
             | 
| 175 | 
            -
                sims_1 = sims_1 | 
| 176 | 
            -
             | 
| 177 | 
            -
                 | 
| 178 | 
            -
             | 
| 179 | 
            -
                sims_2 = sims_2 | 
| 180 | 
            -
             | 
| 181 | 
            -
                 | 
| 182 | 
            -
             | 
| 183 | 
            -
                 | 
| 184 | 
            -
             | 
| 185 | 
            -
                 | 
| 186 | 
            -
             | 
| 187 | 
            -
             | 
| 188 | 
            -
                     | 
| 189 | 
            -
                     | 
| 190 | 
            -
                     | 
| 191 | 
            -
             | 
| 192 | 
            -
             | 
| 193 | 
            -
             | 
| 194 | 
            -
             | 
| 195 | 
            -
             | 
| 196 | 
            -
                                    | 
| 197 | 
            -
                                    | 
| 198 | 
            -
                                    | 
| 199 | 
            -
                                    | 
| 200 | 
            -
                                    | 
| 201 | 
            -
             | 
| 202 | 
            -
             | 
| 203 | 
            -
             | 
| 204 | 
            -
                     | 
| 205 | 
            -
             | 
| 206 | 
            -
             | 
| 207 | 
            -
             | 
| 208 | 
            -
             | 
| 209 | 
            -
                     | 
| 210 | 
            -
                     | 
| 211 | 
            -
                     | 
| 212 | 
            -
             | 
| 213 | 
            -
             | 
| 214 | 
            -
             | 
| 215 | 
            -
             | 
| 216 | 
            -
                     | 
| 217 | 
            -
                     | 
| 218 | 
            -
                     | 
| 219 | 
            -
             | 
| 220 | 
            -
             | 
| 221 | 
            -
             | 
| 222 | 
            -
                 | 
| 223 | 
            -
             | 
| 224 | 
            -
             | 
| 225 | 
            -
                     | 
| 226 | 
            -
                     | 
| 227 | 
            -
             | 
| 228 | 
            -
             | 
| 229 | 
            -
             | 
| 230 | 
            -
             | 
| 231 | 
            -
                     | 
| 232 | 
            -
                     | 
| 233 | 
            -
                     | 
| 234 | 
            -
             | 
| 235 | 
            -
             | 
| 236 | 
            -
             | 
| 237 | 
            -
             | 
| 238 | 
            -
             | 
| 239 | 
            -
                 | 
| 240 | 
            -
                 | 
| 241 | 
            -
             | 
| 242 | 
            -
             | 
| 243 | 
            -
               | 
| 244 | 
            -
             | 
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from collections import defaultdict
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import matplotlib.colors as mcolors
         | 
| 5 | 
            +
            import matplotlib.pyplot as plt
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            import scipy.io.wavfile as wavfile
         | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            import torch.nn.functional as F
         | 
| 10 | 
            +
            import torchvision
         | 
| 11 | 
            +
            from moviepy import *
         | 
| 12 | 
            +
            from moviepy.editor import VideoFileClip, AudioFileClip
         | 
| 13 | 
            +
            from base64 import b64encode
         | 
| 14 | 
            +
            from DenseAV.denseav.shared import pca
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def write_video_with_audio(video_frames, audio_array, video_fps, audio_fps, output_path):
                """
                Writes video frames, optionally muxed with an audio track, to output_path.

                When audio is given, the frames and the audio are first written to two
                temporary files next to output_path, combined with moviepy, and the
                temporary files are removed again (also when a step fails).

                Parameters:
                - video_frames: torch.Tensor of shape (num_frames, height, width, channels)
                - audio_array: torch.Tensor or None. It is permuted with (1, 0) before being
                  handed to scipy's wavfile.write, which expects (num_samples, num_channels),
                  so it is presumably laid out as (num_channels, num_samples) - TODO confirm
                  against callers. If None, only the video stream is written.
                - video_fps: int, frames per second of the video
                - audio_fps: int, sample rate of the audio
                - output_path: str, path to save the final video
                """
                # dirname is '' for a bare file name, and os.makedirs('') raises.
                output_dir = os.path.dirname(output_path)
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)

                video_options = {
                    'crf': '23',
                    'preset': 'slow',
                    'bit_rate': '1000k'}

                if audio_array is None:
                    torchvision.io.write_video(
                        filename=output_path,
                        video_array=video_frames,
                        fps=video_fps,
                        options=video_options
                    )
                    return

                # Derive the temp names from the extension instead of str.replace so they can
                # never collide with output_path (e.g. a non-.mp4 extension, or '.mp4'
                # appearing in a directory name). For 'x.mp4' the names are unchanged.
                root, ext = os.path.splitext(output_path)
                temp_video_path = f"{root}_temp{ext or '.mp4'}"
                temp_audio_path = f"{root}_temp_audio.wav"

                video_clip = None
                audio_clip = None
                try:
                    torchvision.io.write_video(
                        filename=temp_video_path,
                        video_array=video_frames,
                        fps=video_fps,
                        options=video_options
                    )
                    wavfile.write(
                        temp_audio_path,
                        audio_fps,
                        audio_array.cpu().to(torch.float64).permute(1, 0).numpy())

                    video_clip = VideoFileClip(temp_video_path)
                    audio_clip = AudioFileClip(temp_audio_path)
                    final_clip = video_clip.set_audio(audio_clip)
                    final_clip.write_videofile(output_path, codec='libx264', verbose=False)
                finally:
                    # Release the reader handles before deleting the files they point at.
                    for clip in (video_clip, audio_clip):
                        if clip is not None:
                            clip.close()
                    for temp_path in (temp_video_path, temp_audio_path):
                        if os.path.exists(temp_path):
                            os.remove(temp_path)
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            def alpha_blend_layers(layers):
                """
                Composites RGBA layers bottom-to-top and returns the result as 8-bit RGB.

                Parameters:
                - layers: sequence of tensors indexed (batch, channel, height, width) where
                  channels 0-2 are RGB and channel 3 is alpha; layers[0] is the bottom layer.
                  Values are scaled by 255 on output, so they are presumably in [0, 1].

                Returns:
                - torch.uint8 tensor with the channel axis moved last: (batch, height, width, 3)
                """
                composite = layers[0]
                for overlay in layers[1:]:
                    under_rgb = composite[:, :3, :, :]
                    under_alpha = composite[:, 3:4, :, :]
                    over_rgb = overlay[:, :3, :, :]
                    over_alpha = overlay[:, 3:4, :, :]

                    # Standard "over" operator; the clamp guards the division where both
                    # layers are fully transparent.
                    merged_alpha = over_alpha + under_alpha * (1 - over_alpha)
                    weighted_rgb = over_rgb * over_alpha + under_rgb * under_alpha * (1 - over_alpha)
                    merged_rgb = weighted_rgb / merged_alpha.clamp(min=1e-7)

                    composite = torch.cat([merged_rgb, merged_alpha], dim=1)

                rgb_bytes = (composite[:, :3] * 255).clamp(0, 255).to(torch.uint8)
                return rgb_bytes.permute(0, 2, 3, 1)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            def _prep_sims_for_plotting(sim_by_head, frames):
         | 
| 74 | 
            +
                with torch.no_grad():
         | 
| 75 | 
            +
                    results = defaultdict(list)
         | 
| 76 | 
            +
                    n_frames, _, vh, vw = frames.shape
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    sims = sim_by_head.max(dim=1).values
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    n_audio_feats = sims.shape[-1]
         | 
| 81 | 
            +
                    for frame_num in range(n_frames):
         | 
| 82 | 
            +
                        selected_audio_feat = int((frame_num / n_frames) * n_audio_feats)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                        selected_sim = F.interpolate(
         | 
| 85 | 
            +
                            sims[frame_num, :, :, selected_audio_feat].unsqueeze(0).unsqueeze(0),
         | 
| 86 | 
            +
                            size=(vh, vw),
         | 
| 87 | 
            +
                            mode="bicubic")
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                        results["sims_all"].append(selected_sim)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                        for head in range(sim_by_head.shape[1]):
         | 
| 92 | 
            +
                            selected_sim = F.interpolate(
         | 
| 93 | 
            +
                                sim_by_head[frame_num, head, :, :, selected_audio_feat].unsqueeze(0).unsqueeze(0),
         | 
| 94 | 
            +
                                size=(vh, vw),
         | 
| 95 | 
            +
                                mode="bicubic")
         | 
| 96 | 
            +
                            results[f"sims_{head + 1}"].append(selected_sim)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    results = {k: torch.cat(v, dim=0) for k, v in results.items()}
         | 
| 99 | 
            +
                    return results
         | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
| 102 | 
            +
            def get_plasma_with_alpha():
                """
                Builds a 256-entry colormap with plasma's RGB values and an alpha channel
                that rises linearly from 0 (first entry) to 1 (last entry).
                """
                ramp = np.linspace(0, 1, 256)
                plasma_rgb = plt.cm.plasma(ramp)[:, 0:3]
                # One RGBA row per entry: plasma colour followed by the linear alpha.
                rgba_table = np.column_stack((plasma_rgb, ramp))
                return mcolors.ListedColormap(rgba_table)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
| 111 | 
            +
            def get_inferno_with_alpha_2(alpha=0.5, k=30):
                """
                Builds a 256-entry colormap whose lowest k percent of entries are black and
                whose remaining entries span the full inferno colormap, all at constant alpha.

                Parameters:
                - alpha: opacity written to every entry
                - k: percentage of the 256 entries (from the low end) rendered black

                Raises:
                - ValueError: if k maps to a split index outside [0, 256]
                """
                n_entries = 256
                k_fraction = k / 100.0
                threshold_index = int(k_fraction * n_entries)
                # Fail with a readable message; an out-of-range index would otherwise surface
                # as an unrelated NumPy ValueError (negative sample count / shape mismatch).
                if not 0 <= threshold_index <= n_entries:
                    raise ValueError(
                        f"k={k!r} maps to colormap index {threshold_index}, "
                        f"which is outside [0, {n_entries}]")

                # Zero-initialised, so entries below the threshold are already black.
                custom_cmap = np.zeros((n_entries, 4))
                custom_cmap[:, 3] = alpha
                remaining_inferno = plt.cm.inferno(np.linspace(0, 1, n_entries - threshold_index))
                custom_cmap[threshold_index:, :3] = remaining_inferno[:, :3]
                return mcolors.ListedColormap(custom_cmap)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
             | 
| 123 | 
            +
            def get_inferno_with_alpha():
                """
                Builds a 256-entry colormap with inferno's RGB values and an alpha channel
                that rises linearly from 0 (first entry) to 1 (last entry).
                """
                ramp = np.linspace(0, 1, 256)
                inferno_rgb = plt.cm.inferno(ramp)[:, 0:3]
                # One RGBA row per entry: inferno colour followed by the linear alpha.
                rgba_table = np.column_stack((inferno_rgb, ramp))
                return mcolors.ListedColormap(rgba_table)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
             | 
| 132 | 
            +
            # Solid red at every value; only the alpha channel ramps linearly from 0 to 1.
            red_cmap = mcolors.LinearSegmentedColormap(
                'RedMap',
                segmentdata=dict(
                    red=[(0.0, 1.0, 1.0), (1.0, 1.0, 1.0)],
                    green=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
                    blue=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
                    alpha=[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)],
                ),
            )

            # Solid blue at every value; only the alpha channel ramps linearly from 0 to 1.
            blue_cmap = mcolors.LinearSegmentedColormap(
                'BlueMap',
                segmentdata=dict(
                    red=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
                    green=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
                    blue=[(0.0, 1.0, 1.0), (1.0, 1.0, 1.0)],
                    alpha=[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)],
                ),
            )
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            def plot_attention_video(sims_by_head, frames, audio, video_fps, audio_fps, output_filename):
                """
                Renders the head-maximum similarity map as an inferno overlay on top of the
                frames and writes the result, with audio, to output_filename.

                Parameters:
                - sims_by_head: similarity tensor passed to _prep_sims_for_plotting
                - frames: tensor whose shape unpacks as (n_frames, _, height, width); used as
                  the fully opaque bottom layer (presumably RGB in [0, 1] on the CPU - TODO
                  confirm against callers)
                - audio: audio tensor forwarded to write_video_with_audio (may be None)
                - video_fps: int, frames per second of the video
                - audio_fps: int, sample rate of the audio
                - output_filename: str, path of the video to write
                """
                prepped_sims = _prep_sims_for_plotting(sims_by_head, frames)
                n_frames, _, vh, vw = frames.shape

                # Rescale to [0, 1]. The floor on the divisor avoids 0/0 = NaN when no
                # similarity is positive (everything is then mapped to 0).
                sims_all = prepped_sims["sims_all"].clamp_min(0)
                sims_all = sims_all - sims_all.min()
                sims_all = sims_all / sims_all.max().clamp_min(1e-7)

                cmap = get_inferno_with_alpha()
                layer1 = torch.cat([frames, torch.ones(n_frames, 1, vh, vw)], axis=1)
                # squeeze(1) drops only the singleton channel axis; a bare squeeze() would
                # also drop the frame axis for a one-frame clip and break the permute below.
                overlay_rgba = cmap(sims_all.squeeze(1).detach().cpu().numpy())
                layer2 = torch.tensor(overlay_rgba).permute(0, 3, 1, 2)

                write_video_with_audio(
                    alpha_blend_layers([layer1, layer2]),
                    audio,
                    video_fps,
                    audio_fps,
                    output_filename)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
             | 
| 164 | 
            +
            def plot_2head_attention_video(sims_by_head, frames, audio, video_fps, audio_fps, output_filename):
                """
                Renders the similarity maps of the first two heads as red (head 1) and blue
                (head 2) overlays on top of the frames and writes the result, with audio, to
                output_filename. At each pixel only the head with the larger similarity is
                shown.

                Parameters:
                - sims_by_head: similarity tensor passed to _prep_sims_for_plotting; must
                  contain at least two heads
                - frames: tensor whose shape unpacks as (n_frames, _, height, width); used as
                  the fully opaque bottom layer (presumably RGB in [0, 1] on the CPU - TODO
                  confirm against callers)
                - audio: audio tensor forwarded to write_video_with_audio (may be None)
                - video_fps: int, frames per second of the video
                - audio_fps: int, sample rate of the audio
                - output_filename: str, path of the video to write
                """
                prepped_sims = _prep_sims_for_plotting(sims_by_head, frames)
                sims_1 = prepped_sims["sims_1"]
                sims_2 = prepped_sims["sims_2"]

                n_frames, _, vh, vw = frames.shape

                # Winner-takes-all: zero each head wherever the other one is stronger.
                mask = sims_1 > sims_2
                sims_1 = _normalize_sims(sims_1 * mask)
                sims_2 = _normalize_sims(sims_2 * (~mask))

                layer1 = torch.cat([frames, torch.ones(n_frames, 1, vh, vw)], axis=1)
                # squeeze(1) drops only the singleton channel axis; a bare squeeze() would
                # also drop the frame axis for a one-frame clip and break the permute below.
                head1_rgba = red_cmap(sims_1.squeeze(1).detach().cpu().numpy())
                head2_rgba = blue_cmap(sims_2.squeeze(1).detach().cpu().numpy())
                layer2_head1 = torch.tensor(head1_rgba).permute(0, 3, 1, 2)
                layer2_head2 = torch.tensor(head2_rgba).permute(0, 3, 1, 2)

                write_video_with_audio(
                    alpha_blend_layers([layer1, layer2_head1, layer2_head2]),
                    audio,
                    video_fps,
                    audio_fps,
                    output_filename)


            def _normalize_sims(sims):
                """Clamps sims at 0 and rescales to [0, 1]; an all-zero map stays all zero."""
                sims = sims.clamp_min(0)
                sims = sims - sims.min()
                # A head that never wins the mask is all zeros; floor the divisor so the
                # result is 0 instead of 0/0 = NaN.
                return sims / sims.max().clamp_min(1e-7)
         | 
| 193 | 
            +
             | 
| 194 | 
            +
             | 
| 195 | 
            +
            def plot_feature_video(image_feats,
                                   audio_feats,
                                   frames,
                                   audio,
                                   video_fps,
                                   audio_fps,
                                   video_filename,
                                   audio_filename):
                """
                Writes two videos visualising jointly PCA-reduced features: the image features
                upsampled to the frame resolution (video_filename), and a strip of the audio
                features with a white cursor marking the current frame (audio_filename).

                Parameters:
                - image_feats, audio_feats: feature tensors; moved to the CPU and reduced
                  together by pca (presumably to 3 colour channels - TODO confirm in
                  DenseAV.denseav.shared)
                - frames: tensor whose shape unpacks as (_, _, height, width); only the
                  spatial size is used
                - audio: audio tensor forwarded to write_video_with_audio (may be None)
                - video_fps: int, frames per second of both videos
                - audio_fps: int, sample rate of the audio
                - video_filename: str, output path of the image-feature video
                - audio_filename: str, output path of the audio-feature video
                """

                def _as_uint8_frames(feats):
                    # Channel axis last, scaled to 0..255.
                    return (feats.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)

                with torch.no_grad():
                    reduced, _ = pca([image_feats.cpu(), audio_feats.cpu()])
                    [red_img_feats, red_audio_feats] = reduced

                    _, _, frame_h, frame_w = frames.shape
                    red_img_feats = F.interpolate(red_img_feats, size=(frame_h, frame_w), mode="bicubic")
                    n_video_frames = red_img_feats.shape[0]

                    # Keep the first audio item and resample it to 50 rows and one column
                    # per video frame.
                    first_audio = red_audio_feats[0].unsqueeze(0)
                    red_audio_feats = F.interpolate(first_audio, size=(50, n_video_frames), mode="bicubic")

                write_video_with_audio(
                    _as_uint8_frames(red_img_feats),
                    audio,
                    video_fps,
                    audio_fps,
                    video_filename)

                # One copy of the strip per frame, enlarged 6x so each frame owns 6 columns.
                audio_strip = red_audio_feats.tile(n_video_frames, 1, 1, 1)
                audio_strip = F.interpolate(audio_strip, scale_factor=6, mode="bicubic")
                strip_width = audio_strip.shape[-1]
                for frame_idx in range(n_video_frames):
                    cursor = frame_idx * 6
                    left = max(cursor - 2, 0)
                    right = min(cursor + 2, strip_width)
                    # Paint the cursor columns white in this frame's copy.
                    audio_strip[frame_idx, :, :, left:right] = 1

                write_video_with_audio(
                    _as_uint8_frames(audio_strip),
                    audio,
                    video_fps,
                    audio_fps,
                    audio_filename)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
             | 
| 238 | 
            +
            def display_video_in_notebook(path):
                """
                Embeds the video file at path in the current notebook output as a
                base64 data URL inside an HTML <video> element.

                Parameters:
                - path: str, path of the video file to show
                """
                # Imported lazily so the module can be used without IPython installed.
                from IPython.display import HTML, display
                # Context manager so the file handle is closed right after reading.
                with open(path, 'rb') as video_file:
                    mp4 = video_file.read()
                data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
                display(HTML("""
              <video width=400 controls>
                    <source src="%s" type="video/mp4">
              </video>
              """ % data_url))
         |