import pathlib
import tempfile
import unittest

import torch
from PIL import Image

from finetrainers.data import (
    ImageCaptionFilePairDataset,
    ImageFileCaptionFileListDataset,
    ImageFolderDataset,
    ValidationDataset,
    VideoCaptionFilePairDataset,
    VideoFileCaptionFileListDataset,
    VideoFolderDataset,
    VideoWebDataset,
    initialize_dataset,
)
from finetrainers.utils import find_files

from .utils import create_dummy_directory_structure


class DatasetTesterMixin:
    """Shared fixture that materializes a dummy dataset in a temporary directory.

    Subclasses must set ``num_data_files`` and ``directory_structure``; both are
    forwarded, together with ``caption`` and ``metadata_extension``, to
    ``create_dummy_directory_structure`` during ``setUp``.
    """

    # Number of data files the fixture should contain (required).
    num_data_files = None
    # Relative paths describing the dataset layout to create (required).
    directory_structure = None
    # Caption every generated sample is expected to carry.
    caption = "A cat ruling the world"
    # Forwarded to create_dummy_directory_structure; presumably the data-file
    # extension used when generating metadata -- confirm against the helper.
    metadata_extension = None

    def setUp(self):
        """Validate the subclass configuration and build the dummy dataset.

        Raises:
            ValueError: If ``num_data_files`` or ``directory_structure`` is unset.
        """
        if self.num_data_files is None:
            raise ValueError("num_data_files is not defined")
        if self.directory_structure is None:
            # The message names the attribute that is actually checked.
            raise ValueError("directory_structure is not defined")

        self.tmpdir = tempfile.TemporaryDirectory()
        # The TemporaryDirectory object itself (not its ``.name``) is passed on.
        create_dummy_directory_structure(
            self.directory_structure, self.tmpdir, self.num_data_files, self.caption, self.metadata_extension
        )

    def tearDown(self):
        """Remove the temporary dataset directory."""
        self.tmpdir.cleanup()


class ImageDatasetTesterMixin(DatasetTesterMixin):
    """Fixture mixin for image datasets (``metadata_extension`` is ``"jpg"``)."""

    metadata_extension = "jpg"


class VideoDatasetTesterMixin(DatasetTesterMixin):
    """Fixture mixin for video datasets (``metadata_extension`` is ``"mp4"``)."""

    metadata_extension = "mp4"


class ImageCaptionFilePairDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageCaptionFilePairDataset`` on paired ``N.jpg`` / ``N.txt`` files."""

    num_data_files = 3
    directory_structure = [
        "0.jpg",
        "1.jpg",
        "2.jpg",
        "0.txt",
        "1.txt",
        "2.txt",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = ImageCaptionFilePairDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each item carries the fixture caption and a (3, 64, 64) image tensor."""
        iterator = iter(self.dataset)
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))
            self.assertEqual(item["image"].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` resolves this layout to the file-pair dataset."""
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageCaptionFilePairDataset)


class ImageFileCaptionFileListDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageFileCaptionFileListDataset`` on ``prompts.txt`` / ``images.txt`` lists."""

    num_data_files = 3
    directory_structure = [
        "prompts.txt",
        "images.txt",
        "images/",
        "images/0.jpg",
        "images/1.jpg",
        "images/2.jpg",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = ImageFileCaptionFileListDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each item carries the fixture caption and a (3, 64, 64) image tensor."""
        iterator = iter(self.dataset)
        # Iterate over the configured file count rather than a literal.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))
            self.assertEqual(item["image"].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` resolves this layout to the file-list dataset."""
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFileCaptionFileListDataset)


class ImageFolderDatasetFastTests___CSV(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageFolderDataset`` on a folder described by ``metadata.csv``."""

    num_data_files = 3
    directory_structure = [
        "metadata.csv",
        "0.jpg",
        "1.jpg",
        "2.jpg",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each item exposes the fixture caption and an image tensor."""
        iterator = iter(self.dataset)
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))

    def test_initialize_dataset(self):
        """``initialize_dataset`` resolves this layout to the folder dataset."""
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFolderDataset)


class ImageFolderDatasetFastTests___JSONL(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageFolderDataset`` on a folder described by ``metadata.jsonl``."""

    num_data_files = 3
    directory_structure = [
        "metadata.jsonl",
        "0.jpg",
        "1.jpg",
        "2.jpg",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each item exposes the fixture caption and an image tensor."""
        iterator = iter(self.dataset)
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))

    def test_initialize_dataset(self):
        """``initialize_dataset`` resolves this layout to the folder dataset."""
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFolderDataset)


class VideoCaptionFilePairDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoCaptionFilePairDataset`` on paired ``N.mp4`` / ``N.txt`` files."""

    num_data_files = 3
    directory_structure = [
        "0.mp4",
        "1.mp4",
        "2.mp4",
        "0.txt",
        "1.txt",
        "2.txt",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = VideoCaptionFilePairDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each item carries the caption and a video tensor of 4 entries shaped (3, 64, 64)."""
        iterator = iter(self.dataset)
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` resolves this layout to the file-pair dataset."""
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoCaptionFilePairDataset)


class VideoFileCaptionFileListDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoFileCaptionFileListDataset`` on ``prompts.txt`` / ``videos.txt`` lists."""

    num_data_files = 3
    directory_structure = [
        "prompts.txt",
        "videos.txt",
        "videos/",
        "videos/0.mp4",
        "videos/1.mp4",
        "videos/2.mp4",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = VideoFileCaptionFileListDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each item carries the caption and a video tensor of 4 entries shaped (3, 64, 64)."""
        iterator = iter(self.dataset)
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` resolves this layout to the file-list dataset."""
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFileCaptionFileListDataset)


class VideoFolderDatasetFastTests___CSV(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoFolderDataset`` on a folder described by ``metadata.csv``."""

    num_data_files = 3
    directory_structure = [
        "metadata.csv",
        "0.mp4",
        "1.mp4",
        "2.mp4",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each item exposes the caption and a video tensor of 4 entries shaped (3, 64, 64)."""
        iterator = iter(self.dataset)
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` resolves this layout to the folder dataset."""
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFolderDataset)


class VideoFolderDatasetFastTests___JSONL(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoFolderDataset`` on a folder described by ``metadata.jsonl``."""

    num_data_files = 3
    directory_structure = [
        "metadata.jsonl",
        "0.mp4",
        "1.mp4",
        "2.mp4",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each item exposes the caption and a video tensor of 4 entries shaped (3, 64, 64)."""
        iterator = iter(self.dataset)
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` resolves this layout to the folder dataset."""
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFolderDataset)


class ImageWebDatasetFastTests(unittest.TestCase):
    # TODO(aryan): setup a dummy dataset
    pass


class VideoWebDatasetFastTests(unittest.TestCase):
    """Tests ``VideoWebDataset`` against the ``finetrainers/dummy-squish-wds`` dataset."""

    def setUp(self):
        # NOTE(review): num_data_files is not read by any test in this class;
        # presumably it records the size of the referenced dataset -- confirm.
        self.num_data_files = 15
        self.dataset = VideoWebDataset("finetrainers/dummy-squish-wds", infinite=False)

    def test_getitem(self):
        """The first three items expose a caption and a 121-entry (3, 720, 1280) video tensor."""
        for index, item in enumerate(self.dataset):
            # Only the first three samples are checked.
            if index > 2:
                break
            self.assertIn("caption", item)
            self.assertIn("video", item)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 121)
            self.assertEqual(item["video"][0].shape, (3, 720, 1280))

    def test_initialize_dataset(self):
        """``initialize_dataset`` resolves the dataset name to ``VideoWebDataset``."""
        dataset = initialize_dataset("finetrainers/dummy-squish-wds", "video", infinite=False)
        self.assertIsInstance(dataset, VideoWebDataset)


class DatasetUtilsFastTests(unittest.TestCase):
    """Tests for ``find_files`` with and without a ``depth`` argument."""

    @staticmethod
    def _create_txt_file(directory):
        """Create an empty, closed ``.txt`` file in ``directory`` and return its path."""
        # delete=False keeps the file on disk after the handle is closed; the
        # ``with`` block closes the handle so no file descriptor is leaked.
        with tempfile.NamedTemporaryFile(dir=directory, suffix=".txt", delete=False) as handle:
            return handle.name

    def test_find_files_depth_0(self):
        """All three ``.txt`` files directly inside the directory are found."""
        with tempfile.TemporaryDirectory() as tmpdir:
            expected = [self._create_txt_file(tmpdir) for _ in range(3)]

            files = find_files(tmpdir, "*.txt")

            self.assertEqual(len(files), 3)
            for path in expected:
                self.assertIn(path, files)

    def test_find_files_depth_n(self):
        """``depth`` controls how many nested directory levels are searched."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # mkdtemp returns plain paths; the nested directories are removed
            # together with ``tmpdir`` instead of by separate finalizers.
            dir1 = tempfile.mkdtemp(dir=tmpdir)
            dir2 = tempfile.mkdtemp(dir=dir1)
            file1 = self._create_txt_file(dir1)
            file2 = self._create_txt_file(dir2)

            files = find_files(tmpdir, "*.txt", depth=1)
            self.assertEqual(len(files), 1)
            self.assertIn(file1, files)
            self.assertNotIn(file2, files)

            files = find_files(tmpdir, "*.txt", depth=2)
            self.assertEqual(len(files), 2)
            self.assertIn(file1, files)
            self.assertIn(file2, files)
            self.assertNotIn(dir1, files)
            self.assertNotIn(dir2, files)


class ValidationDatasetFastTests(unittest.TestCase):
    """Tests ``ValidationDataset`` built from a ``metadata.csv`` of image paths."""

    def setUp(self):
        """Write three 64x64 RGB JPEGs and a CSV that references them."""
        num_data_files = 3

        self.tmpdir = tempfile.TemporaryDirectory()
        metadata_filename = pathlib.Path(self.tmpdir.name) / "metadata.csv"

        with open(metadata_filename, "w") as f:
            f.write("caption,image_path,video_path\n")
            for i in range(num_data_files):
                Image.new("RGB", (64, 64)).save((pathlib.Path(self.tmpdir.name) / f"{i}.jpg").as_posix())
                # The video_path column is left empty for every row.
                f.write(f"test caption,{self.tmpdir.name}/{i}.jpg,\n")

        self.dataset = ValidationDataset(metadata_filename.as_posix())

    def tearDown(self):
        """Remove the temporary directory holding the images and metadata."""
        self.tmpdir.cleanup()

    def test_getitem(self):
        """Items keep their ``image_path`` and expose a 64x64 PIL image under ``image``."""
        for i, data in enumerate(self.dataset):
            self.assertEqual(data["image_path"], f"{self.tmpdir.name}/{i}.jpg")
            self.assertIsInstance(data["image"], Image.Image)
            self.assertEqual(data["image"].size, (64, 64))