Spaces:
Build error
Build error
import skvideo | |
# assert skvideo.__version__ >= "1.1.11" | |
import os | |
import skvideo.io | |
import cv2 | |
# install the following packages: # | |
# conda install -c conda-forge scikit-video ffmpeg # | |
import os | |
import torch | |
import torchvision | |
from PIL import Image | |
import numpy as np | |
from einops import rearrange | |
class VideoUtils:
    """Read and/or write videos through scikit-video's FFmpeg wrappers.

    With ``video_path`` the source is probed with ``skvideo.io.ffprobe`` and an
    ``FFmpegReader`` is opened. The source's pixel format and colour tags
    (range, colourspace, primaries, transfer) are copied onto the writer
    options so a re-encoded copy keeps the source's colour description.
    With ``output_video_path`` an ``FFmpegWriter`` is opened.

    Raw frames are exchanged as ``bgr24`` for ordinary pixel formats,
    ``bgr48le`` for pixel formats ending in ``le``, and ``bgra64le`` for
    ``yuva*`` (alpha) formats. Output is always encoded with libx264 at
    constant quality ``-crf 17``.
    """

    _CODEC_NAME = 'libx264'
    _CRF = '17'
    _SWS_FLAGS = 'full_chroma_int+bitexact+accurate_rnd'
    _DEFAULT_PIX_FMT = 'yuv420p'
    # (ffprobe stream key, ffmpeg option) pairs copied from the source video.
    # The tuple order is the order the options are inserted into the dicts.
    _COLOR_OPTIONS = (
        ('@color_range', '-color_range'),
        ('@color_space', '-colorspace'),
        ('@color_primaries', '-color_primaries'),
        ('@color_transfer', '-color_trc'),
    )

    def __init__(self, video_path=None, output_video_path=None, bit_rate='origin', fps=25):
        """Open the reader and/or the writer.

        Args:
            video_path: source video to read. It also supplies the pixel format
                and colour metadata for the writer. If None, the writer uses
                ``yuv420p`` and no colour tags, and no reader is opened.
            output_video_path: destination video. If None, no writer is opened.
            bit_rate: currently unused. Bit-rate control is disabled in favour
                of constant-quality ``-crf 17``. The argument is kept so
                existing callers keep working.
            fps: frame rate passed as ``-r`` to the reader and the writer.

        Raises:
            ValueError: if ffprobe reports no video stream for ``video_path``.
        """
        self.reader = None
        self.writer = None

        pix_fmt = self._DEFAULT_PIX_FMT
        video_meta = {}
        if video_path is not None:
            video_meta = skvideo.io.ffprobe(video_path).get('video')
            if video_meta is None:
                # The probe result is keyed by stream type. Without a 'video'
                # entry there is nothing to read, so fail with a clear message
                # instead of a bare KeyError.
                raise ValueError(f"no video stream found in {video_path!r}")
            pix_fmt = video_meta.get('@pix_fmt')
            if pix_fmt is None:
                pix_fmt = self._DEFAULT_PIX_FMT

        reader_output_dict, writer_input_dict, writer_output_dict = self._build_options(
            pix_fmt, fps, video_meta)
        # Diagnostic output of the writer options.
        print(writer_input_dict)
        print(writer_output_dict)

        if video_path is not None:
            self.reader = skvideo.io.FFmpegReader(video_path, outputdict=reader_output_dict)
        if output_video_path is not None:
            self.writer = skvideo.io.FFmpegWriter(
                output_video_path,
                inputdict=writer_input_dict,
                outputdict=writer_output_dict,
                verbosity=1,
            )

    @staticmethod
    def _raw_pix_fmt(pix_fmt):
        # Pick the raw pixel format used to pass frames between Python and ffmpeg.
        if pix_fmt.startswith('yuva'):
            # Alpha channel present: keep it, processed as 16-bit BGRA.
            return 'bgra64le'
        if pix_fmt.endswith('le'):
            # Little-endian high-bit-depth source: use 16-bit BGR.
            return 'bgr48le'
        return 'bgr24'

    @classmethod
    def _build_options(cls, pix_fmt, fps, video_meta):
        # Build the (reader output, writer input, writer output) ffmpeg option dicts.
        rate = str(fps)
        raw_fmt = cls._raw_pix_fmt(pix_fmt)
        reader_output_dict = {'-r': rate, '-pix_fmt': raw_fmt}
        writer_input_dict = {'-r': rate, '-pix_fmt': raw_fmt}
        writer_output_dict = {
            '-pix_fmt': pix_fmt,
            '-r': rate,
            '-vcodec': cls._CODEC_NAME,
            '-crf': cls._CRF,
        }
        for meta_key, option in cls._COLOR_OPTIONS:
            value = video_meta.get(meta_key)
            if value is not None:
                # Copied to both the writer's input and output options.
                writer_output_dict[option] = value
                writer_input_dict[option] = value
        writer_output_dict['-sws_flags'] = cls._SWS_FLAGS
        reader_output_dict['-sws_flags'] = cls._SWS_FLAGS
        return reader_output_dict, writer_input_dict, writer_output_dict

    def getframes(self):
        """Return scikit-video's frame generator (``nextFrame()``) for the source.

        Raises:
            RuntimeError: if the instance was created without ``video_path``.
        """
        if self.reader is None:
            raise RuntimeError("VideoUtils was created without video_path; nothing to read")
        return self.reader.nextFrame()

    def writeframe(self, frame):
        """Append ``frame`` to the output video, or close the writer when ``frame`` is None.

        Raises:
            RuntimeError: if the instance was created without ``output_video_path``.
        """
        if self.writer is None:
            raise RuntimeError("VideoUtils was created without output_video_path; nothing to write")
        if frame is None:
            self.writer.close()
        else:
            self.writer.writeFrame(frame)
def save_videos_from_pil(pil_images, path, fps=8):
    """Save a sequence of PIL images as a video file.

    The output format follows the extension of ``path``. ``.gif`` (any case)
    writes an animated GIF with Pillow. Every other extension (normally
    ``.mp4``) is encoded through ``VideoUtils``.

    Args:
        pil_images: non-empty sequence of PIL images of identical size.
        path: destination file. Missing parent directories are created.
        fps: output frame rate (frames per second).

    Raises:
        ValueError: if ``pil_images`` is empty.
    """
    if not pil_images:
        raise ValueError("pil_images must contain at least one frame")

    out_dir = os.path.dirname(path)
    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if os.path.splitext(path)[1].lower() == ".gif":
        pil_images[0].save(
            fp=path,
            format="GIF",
            append_images=pil_images[1:],
            save_all=True,
            duration=(1 / fps * 1000),
            loop=0,
            optimize=False,
            lossless=True,
        )
        return

    video_cap = VideoUtils(output_video_path=path, fps=fps)
    try:
        for pil_image in pil_images:
            # VideoUtils is fed BGR frames. Convert to RGB first so grayscale,
            # palette and RGBA images also yield a 3-channel array.
            image_bgr = np.array(pil_image.convert("RGB"))[:, :, [2, 1, 0]]
            video_cap.writeframe(image_bgr)
    finally:
        # Passing None closes the writer, even if a frame failed to convert.
        video_cap.writeframe(None)
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    """Tile a batch of videos frame-by-frame into a grid and save the result.

    Args:
        videos: tensor in ``(b, c, t, h, w)`` layout, as required by the
            ``rearrange`` pattern below. Values are expected in ``[0, 1]``,
            or in ``[-1, 1]`` when ``rescale`` is True. Anything outside that
            range is clipped before quantization.
        path: destination file, forwarded to ``save_videos_from_pil``.
        rescale: map ``[-1, 1]`` input to ``[0, 1]`` before quantizing.
        n_rows: forwarded as ``torchvision.utils.make_grid``'s ``nrow`` argument.
        fps: output frame rate.
    """
    per_frame = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for batch in per_frame:
        grid = torchvision.utils.make_grid(batch, nrow=n_rows)  # (c, h, w)
        grid = grid.permute(1, 2, 0).squeeze(-1)  # (h, w, c); squeezes c only when it is 1
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp before the uint8 cast: out-of-range floats do not saturate
        # when cast to uint8. detach()/cpu() allow CUDA and grad-tracking
        # tensors to be converted to numpy.
        pixels = (grid.detach().clamp(0.0, 1.0) * 255).cpu().numpy().astype(np.uint8)
        outputs.append(Image.fromarray(pixels))

    out_dir = os.path.dirname(path)
    if out_dir:  # os.makedirs('') raises for a bare filename
        os.makedirs(out_dir, exist_ok=True)
    save_videos_from_pil(outputs, path, fps)
def save_video(video, path: str, rescale=False, n_rows=6, fps=8):
    """Wrap every frame of ``video`` in a PIL image and save the clip to ``path``.

    ``rescale`` and ``n_rows`` are accepted for signature compatibility but are
    ignored. Only ``fps`` is forwarded to ``save_videos_from_pil``.
    Frames presumably need to be arrays that ``Image.fromarray`` accepts
    (e.g. uint8 ``(h, w)`` or ``(h, w, c)``) — confirm at call sites.
    """
    frames = [Image.fromarray(frame) for frame in video]
    save_videos_from_pil(frames, path, fps)