# Senorita / dataset_demo_videos.py
# Uploaded by PengWeixuanSZU (upload "Upload 66 files", commit 84abaca, verified; 3.73 kB).
import os
import cv2
import decord
import torch
from torch.utils.data import Dataset
import numpy as np
class VideoDataset(Dataset):
    """Demo dataset: edited first frame + source video + text prompts.

    ``root_dir`` is scanned (non-recursively) for ``.png`` files whose name
    contains ``"edit"`` (case-insensitive). A file such as ``3_edit.png``
    yields the sample ``(3_edit.png, 3.mp4, 3.txt)`` from the same directory.
    Samples are ordered by PNG file name so indices are stable.

    Each item is a dict:
        'image'      -- tensor (height, width, 3): the edited PNG, resized and
                        converted from OpenCV's BGR order to RGB.
        'frames'     -- tensor (num_frames, height, width, 3): the first
                        ``num_frames`` decoded video frames, resized (channel
                        order as returned by decord; presumably RGB - confirm).
        'pos_prompt' -- first line of the UTF-8 .txt file, stripped.
        'neg_prompt' -- second line of the .txt file, stripped.
        'image_path' -- path of the edited PNG.

    Loading is best-effort: if the image, video or text cannot be loaded, an
    error is printed and black frames / empty prompts are returned instead,
    so one broken sample does not stop iteration.

    Args:
        root_dir: directory holding the ``*_edit.png`` / ``.mp4`` / ``.txt``
            files. A missing directory yields an empty dataset.
        width, height: output frame size in pixels (default 768x448).
        num_frames: number of leading video frames to return (default 33).
        side_bar_stems: sample ids (file stems, e.g. ``"6"`` for ``6.mp4``)
            whose image and frames are narrowed horizontally between black
            bars. Matched exactly, so ``"6"`` does not also catch ``16.mp4``.
    """

    # Extra canvas width (pixels) for the side-bar squeeze; the content sits
    # in the middle, leaving pad // 2 black columns on each side before the
    # canvas is resized back to (width, height).
    SIDE_BAR_PAD = 448

    def __init__(self, root_dir, width=768, height=448, num_frames=33,
                 side_bar_stems=("6",)):
        import re

        self.root_dir = root_dir
        self.width = width
        self.height = height
        self.num_frames = num_frames
        self.side_bar_stems = frozenset(side_bar_stems)
        self.data = []

        if not os.path.isdir(root_dir):
            return
        # Sorted: os.listdir order is filesystem-dependent.
        for file_name in sorted(os.listdir(root_dir)):
            lower = file_name.lower()
            if 'edit' not in lower or not lower.endswith('.png'):
                continue
            # Case-insensitive, so "3_EDIT.png" maps to sample "3" just like
            # "3_edit.png". Without "_edit" the whole name is kept (as before).
            number = re.split('_edit', file_name, maxsplit=1,
                              flags=re.IGNORECASE)[0]
            self.data.append((
                os.path.join(root_dir, file_name),
                os.path.join(root_dir, f"{number}.mp4"),
                os.path.join(root_dir, f"{number}.txt"),
            ))

    def __len__(self):
        return len(self.data)

    def _needs_side_bars(self, video_file):
        """Return True if the sample's file stem is listed in ``side_bar_stems``."""
        stem = os.path.splitext(os.path.basename(video_file))[0]
        return stem in self.side_bar_stems

    def _squeeze_between_bars(self, img):
        """Paste a (height, width, 3) image centred on a wider black canvas and
        resize the canvas back to (height, width), narrowing the content."""
        pad = self.SIDE_BAR_PAD
        canvas = np.zeros((self.height, self.width + pad, 3), dtype=np.uint8)
        canvas[:, pad // 2: pad // 2 + self.width, :] = img
        # cv2.resize takes dsize as (width, height).
        return cv2.resize(canvas, (self.width, self.height))

    def __getitem__(self, idx):
        """Load sample ``idx``; see the class docstring for the returned dict."""
        png_file, video_file, txt_file = self.data[idx]
        side_bars = self._needs_side_bars(video_file)
        size = (self.width, self.height)  # cv2 dsize order

        try:
            image = cv2.imread(png_file)
            if image is None:
                # cv2.imread reports failure by returning None, not raising.
                raise ValueError("cv2.imread could not read the file")
            resized_image = cv2.resize(image, size)
            resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
            if side_bars:
                resized_image = self._squeeze_between_bars(resized_image)
        except Exception as e:
            print("*"*200)
            print(f"Error reading or resizing image {png_file}: {e}")
            resized_image = np.zeros((self.height, self.width, 3), dtype=np.uint8)

        try:
            vr = decord.VideoReader(video_file)
            # A clip shorter than num_frames presumably makes get_batch raise;
            # the whole clip then falls back to black frames below.
            frames = vr.get_batch(list(range(self.num_frames))).asnumpy()
            resized_frames = []
            for raw in frames:
                frame = cv2.resize(raw, size)
                if side_bars:
                    frame = self._squeeze_between_bars(frame)
                resized_frames.append(frame)
        except Exception as e:
            print("*"*200, video_file, "*"*200)
            print(f"Error reading or resizing video {video_file}: {e}")
            resized_frames = [np.zeros((self.height, self.width, 3), dtype=np.uint8)
                              for _ in range(self.num_frames)]

        try:
            # Explicit UTF-8 so prompts decode the same on every platform.
            with open(txt_file, 'r', encoding='utf-8') as f:
                pos_prompt = f.readline().strip()
                neg_prompt = f.readline().strip()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error reading text file {txt_file}: {e}")
            pos_prompt = ""
            neg_prompt = ""

        return {
            'image': torch.from_numpy(resized_image),
            'frames': torch.from_numpy(np.array(resized_frames)),
            'pos_prompt': pos_prompt,
            'neg_prompt': neg_prompt,
            'image_path': png_file,  # path of the edited PNG
        }
if __name__ == "__main__":
    # Example usage: load the demo set and print the first sample.
    root_dir = 'demo_videos/videos'
    dataset = VideoDataset(root_dir)
    if len(dataset) == 0:
        # Guard: dataset[0] would raise IndexError on an empty directory.
        print(f"No '*_edit.png' samples found in {root_dir}")
    else:
        sample = dataset[0]
        print(sample['image'].shape)
        print(sample['frames'].shape)
        print(sample['pos_prompt'])
        print(sample['neg_prompt'])
        print(sample['image_path'])