# Spaces: Running on Zero
import logging
import os
import re

import cv2
import decord
import numpy as np
import torch
from torch.utils.data import Dataset
logger = logging.getLogger(__name__)


class VideoDataset(Dataset):
    """Samples of (edited PNG, source video, prompt file) from one flat folder.

    ``root_dir`` is scanned non-recursively for ``.png`` files whose name
    contains ``_edit`` (case-insensitive), e.g. ``3_edit.png``. The text
    before the first ``_edit`` is the sample id. The companion files are
    expected at ``<id>.mp4`` and ``<id>.txt`` in the same folder. Their
    existence is not checked here; ``__getitem__`` substitutes zero-filled
    arrays / empty strings for anything it cannot read and logs a warning.

    Each item is a dict with:
        image:      uint8 tensor of shape (HEIGHT, WIDTH, 3), converted BGR->RGB.
        frames:     uint8 tensor of shape (NUM_FRAMES, HEIGHT, WIDTH, C), the
                    first NUM_FRAMES decoded frames (channel order as returned by
                    decord).
        pos_prompt: first line of the ``.txt`` file, stripped.
        neg_prompt: second line of the ``.txt`` file, stripped.
        image_path: path of the PNG file.
    """

    WIDTH = 768
    HEIGHT = 448
    NUM_FRAMES = 33
    # Total width of black side bars added around padded samples before they
    # are squeezed back to WIDTH x HEIGHT (half on each side).
    PAD_WIDTH = 448
    # Video basenames whose image and frames get pillarboxed. Matched exactly
    # so that e.g. "16.mp4" is not caught by a substring test for "6.mp4".
    PADDED_VIDEOS = frozenset({"6.mp4"})
    _EDIT_MARKER = re.compile("_edit", re.IGNORECASE)

    def __init__(self, root_dir):
        """Index all ``*_edit*.png`` files in ``root_dir``.

        Args:
            root_dir: folder holding the PNG / MP4 / TXT files. A path that is
                not a directory yields an empty dataset (a warning is logged).
        """
        self.root_dir = root_dir
        self.data = []
        if not os.path.isdir(root_dir):
            logger.warning("Dataset root %s is not a directory; dataset is empty", root_dir)
            return
        # Sorted so that indices are reproducible across runs and filesystems.
        for file_name in sorted(os.listdir(root_dir)):
            if not file_name.lower().endswith(".png"):
                continue
            marker = self._EDIT_MARKER.search(file_name)
            if marker is None:
                # e.g. "credit.png" or "4edit.png": no "_edit" separator, so no
                # sample id can be derived.
                continue
            number = file_name[: marker.start()]
            self.data.append((
                os.path.join(root_dir, file_name),
                os.path.join(root_dir, f"{number}.mp4"),
                os.path.join(root_dir, f"{number}.txt"),
            ))

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` (see the class docstring for the returned keys)."""
        png_file, video_file, txt_file = self.data[idx]
        pad = os.path.basename(video_file) in self.PADDED_VIDEOS
        image = self._load_image(png_file, pad)
        frames = self._load_frames(video_file, pad)
        pos_prompt, neg_prompt = self._load_prompts(txt_file)
        return {
            'image': torch.from_numpy(image),
            'frames': torch.from_numpy(frames),
            'pos_prompt': pos_prompt,
            'neg_prompt': neg_prompt,
            'image_path': png_file,
        }

    def _pillarbox(self, frame):
        """Center a (HEIGHT, WIDTH, 3) frame between black bars, resize back to WIDTH x HEIGHT."""
        canvas = np.zeros((self.HEIGHT, self.WIDTH + self.PAD_WIDTH, 3), dtype=np.uint8)
        left = self.PAD_WIDTH // 2
        canvas[:, left:left + self.WIDTH, :] = frame
        return cv2.resize(canvas, (self.WIDTH, self.HEIGHT))

    def _load_image(self, png_file, pad):
        """Read, resize and BGR->RGB convert the PNG; black image on any failure."""
        try:
            image = cv2.imread(png_file)
            if image is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise OSError("cv2.imread could not read the file")
            image = cv2.resize(image, (self.WIDTH, self.HEIGHT))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            return self._pillarbox(image) if pad else image
        except Exception as e:  # best-effort: third-party (cv2) errors vary
            logger.warning("Error reading or resizing image %s: %s", png_file, e)
            return np.zeros((self.HEIGHT, self.WIDTH, 3), dtype=np.uint8)

    def _load_frames(self, video_file, pad):
        """Decode and resize the first NUM_FRAMES frames; zeros on any failure."""
        try:
            reader = decord.VideoReader(video_file)
            if len(reader) < self.NUM_FRAMES:
                raise ValueError(
                    f"video has {len(reader)} frames, need {self.NUM_FRAMES}"
                )
            frames = reader.get_batch(list(range(self.NUM_FRAMES))).asnumpy()
            resized = [cv2.resize(frame, (self.WIDTH, self.HEIGHT)) for frame in frames]
            if pad:
                resized = [self._pillarbox(frame) for frame in resized]
            return np.stack(resized)
        except Exception as e:  # best-effort: third-party (decord/cv2) errors vary
            logger.warning("Error reading or resizing video %s: %s", video_file, e)
            return np.zeros((self.NUM_FRAMES, self.HEIGHT, self.WIDTH, 3), dtype=np.uint8)

    def _load_prompts(self, txt_file):
        """Return (positive, negative) prompts from lines 1 and 2; ("", "") on failure."""
        try:
            # utf-8-sig: prompts may be non-ASCII, and a leading BOM would
            # otherwise survive str.strip() and end up in the prompt.
            with open(txt_file, 'r', encoding='utf-8-sig') as f:
                pos_prompt = f.readline().strip()
                neg_prompt = f.readline().strip()
        except (OSError, UnicodeDecodeError) as e:
            logger.warning("Error reading text file %s: %s", txt_file, e)
            return "", ""
        return pos_prompt, neg_prompt


# Example usage:
#
#     dataset = VideoDataset('demo_videos/videos')
#     sample = dataset[0]  # read the first sample
#     print(sample['image'].shape)
#     print(sample['frames'].shape)
#     print(sample['pos_prompt'])
#     print(sample['neg_prompt'])
#     print(sample['image_path'])
