jbilcke-hf's picture
jbilcke-hf HF Staff
we are going to hack into finetrainers
9fd1204
import pathlib
import tempfile
import unittest
import torch
from PIL import Image
from finetrainers.data import (
ImageCaptionFilePairDataset,
ImageFileCaptionFileListDataset,
ImageFolderDataset,
ValidationDataset,
VideoCaptionFilePairDataset,
VideoFileCaptionFileListDataset,
VideoFolderDataset,
VideoWebDataset,
initialize_dataset,
)
from finetrainers.utils import find_files
from .utils import create_dummy_directory_structure
class DatasetTesterMixin:
    """Shared fixture handling for the on-disk dataset test cases.

    Subclasses set ``num_data_files`` and ``directory_structure``. ``setUp``
    creates a temporary directory and passes it, together with those values,
    ``caption`` and ``metadata_extension``, to
    ``create_dummy_directory_structure``. ``tearDown`` removes the directory.
    """

    # Number of data files; forwarded to create_dummy_directory_structure.
    num_data_files = None
    # Relative paths describing the fixture tree; must be set by subclasses.
    directory_structure = None
    # Caption the dataset tests expect every item to carry.
    caption = "A cat ruling the world"
    # Forwarded to create_dummy_directory_structure; None unless a subclass sets it.
    metadata_extension = None

    def setUp(self):
        """Validate the class configuration and build the temporary fixture tree.

        Raises:
            ValueError: If ``num_data_files`` or ``directory_structure`` is unset.
        """
        if self.num_data_files is None:
            raise ValueError("num_data_files is not defined")
        if self.directory_structure is None:
            # Name the attribute exactly as subclasses have to spell it.
            raise ValueError("directory_structure is not defined")

        self.tmpdir = tempfile.TemporaryDirectory()
        try:
            # The TemporaryDirectory object itself (not its path) is passed on.
            create_dummy_directory_structure(
                self.directory_structure,
                self.tmpdir,
                self.num_data_files,
                self.caption,
                self.metadata_extension,
            )
        except Exception:
            # unittest skips tearDown when setUp raises, so remove the
            # directory here instead of leaking it.
            self.tmpdir.cleanup()
            raise

    def tearDown(self):
        """Delete the temporary fixture directory created by ``setUp``."""
        self.tmpdir.cleanup()
class ImageDatasetTesterMixin(DatasetTesterMixin):
    """Fixture mixin for the image dataset tests; sets the metadata extension to ``jpg``."""

    # Handed to create_dummy_directory_structure by DatasetTesterMixin.setUp.
    metadata_extension = "jpg"
class VideoDatasetTesterMixin(DatasetTesterMixin):
    """Fixture mixin for the video dataset tests; sets the metadata extension to ``mp4``."""

    # Handed to create_dummy_directory_structure by DatasetTesterMixin.setUp.
    metadata_extension = "mp4"
class ImageCaptionFilePairDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageCaptionFilePairDataset`` on a flat folder of ``N.jpg`` / ``N.txt`` files."""

    num_data_files = 3
    directory_structure = ["0.jpg", "1.jpg", "2.jpg", "0.txt", "1.txt", "2.txt"]

    def setUp(self):
        """Build the fixture tree, then open it as a finite dataset."""
        super().setUp()
        self.dataset = ImageCaptionFilePairDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a (3, 64, 64) tensor."""
        samples = iter(self.dataset)
        for _ in range(self.num_data_files):
            sample = next(samples)
            self.assertEqual(sample["caption"], self.caption)
            image = sample["image"]
            self.assertTrue(torch.is_tensor(image))
            self.assertEqual(image.shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``ImageCaptionFilePairDataset`` for this layout."""
        self.assertIsInstance(
            initialize_dataset(self.tmpdir.name, "image", infinite=False),
            ImageCaptionFilePairDataset,
        )
class ImageFileCaptionFileListDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageFileCaptionFileListDataset`` on a ``prompts.txt`` / ``images.txt`` layout.

    The fixture tree holds the two list files plus an ``images/`` folder with
    three ``.jpg`` entries.
    """

    num_data_files = 3
    directory_structure = [
        "prompts.txt",
        "images.txt",
        "images/",
        "images/0.jpg",
        "images/1.jpg",
        "images/2.jpg",
    ]

    def setUp(self):
        """Build the fixture tree, then open it as a finite dataset."""
        super().setUp()
        self.dataset = ImageFileCaptionFileListDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a (3, 64, 64) tensor."""
        iterator = iter(self.dataset)
        # Driven by num_data_files so the loop tracks the fixture size.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))
            self.assertEqual(item["image"].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``ImageFileCaptionFileListDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFileCaptionFileListDataset)
class ImageFolderDatasetFastTests___CSV(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageFolderDataset`` on a folder of images described by ``metadata.csv``."""

    num_data_files = 3
    directory_structure = [
        "metadata.csv",
        "0.jpg",
        "1.jpg",
        "2.jpg",
    ]

    def setUp(self):
        """Build the fixture tree, then open it as a finite dataset."""
        super().setUp()
        self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a tensor image."""
        iterator = iter(self.dataset)
        # Driven by num_data_files so the loop tracks the fixture size.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``ImageFolderDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFolderDataset)
class ImageFolderDatasetFastTests___JSONL(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests ``ImageFolderDataset`` on a folder of images described by ``metadata.jsonl``."""

    num_data_files = 3
    directory_structure = [
        "metadata.jsonl",
        "0.jpg",
        "1.jpg",
        "2.jpg",
    ]

    def setUp(self):
        """Build the fixture tree, then open it as a finite dataset."""
        super().setUp()
        self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a tensor image."""
        iterator = iter(self.dataset)
        # Driven by num_data_files so the loop tracks the fixture size.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``ImageFolderDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFolderDataset)
class VideoCaptionFilePairDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoCaptionFilePairDataset`` on a flat folder of ``N.mp4`` / ``N.txt`` files."""

    num_data_files = 3
    directory_structure = ["0.mp4", "1.mp4", "2.mp4", "0.txt", "1.txt", "2.txt"]

    def setUp(self):
        """Build the fixture tree, then open it as a finite dataset."""
        super().setUp()
        self.dataset = VideoCaptionFilePairDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a 4-long video tensor."""
        samples = iter(self.dataset)
        for _ in range(self.num_data_files):
            sample = next(samples)
            self.assertEqual(sample["caption"], self.caption)
            video = sample["video"]
            self.assertTrue(torch.is_tensor(video))
            # Length 4 along the first axis, each entry shaped (3, 64, 64).
            self.assertEqual(len(video), 4)
            self.assertEqual(video[0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``VideoCaptionFilePairDataset`` for this layout."""
        self.assertIsInstance(
            initialize_dataset(self.tmpdir.name, "video", infinite=False),
            VideoCaptionFilePairDataset,
        )
class VideoFileCaptionFileListDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoFileCaptionFileListDataset`` on a ``prompts.txt`` / ``videos.txt`` layout.

    The fixture tree holds the two list files plus a ``videos/`` folder with
    three ``.mp4`` entries.
    """

    num_data_files = 3
    directory_structure = [
        "prompts.txt",
        "videos.txt",
        "videos/",
        "videos/0.mp4",
        "videos/1.mp4",
        "videos/2.mp4",
    ]

    def setUp(self):
        """Build the fixture tree, then open it as a finite dataset."""
        super().setUp()
        self.dataset = VideoFileCaptionFileListDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a 4-long video tensor."""
        iterator = iter(self.dataset)
        # Driven by num_data_files so the loop tracks the fixture size.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``VideoFileCaptionFileListDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFileCaptionFileListDataset)
class VideoFolderDatasetFastTests___CSV(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoFolderDataset`` on a folder of videos described by ``metadata.csv``."""

    num_data_files = 3
    directory_structure = [
        "metadata.csv",
        "0.mp4",
        "1.mp4",
        "2.mp4",
    ]

    def setUp(self):
        """Build the fixture tree, then open it as a finite dataset."""
        super().setUp()
        self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a 4-long video tensor."""
        iterator = iter(self.dataset)
        # Driven by num_data_files so the loop tracks the fixture size.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``VideoFolderDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFolderDataset)
class VideoFolderDatasetFastTests___JSONL(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests ``VideoFolderDataset`` on a folder of videos described by ``metadata.jsonl``."""

    num_data_files = 3
    directory_structure = [
        "metadata.jsonl",
        "0.mp4",
        "1.mp4",
        "2.mp4",
    ]

    def setUp(self):
        """Build the fixture tree, then open it as a finite dataset."""
        super().setUp()
        self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        """Each of the first ``num_data_files`` items has the caption and a 4-long video tensor."""
        iterator = iter(self.dataset)
        # Driven by num_data_files so the loop tracks the fixture size.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``VideoFolderDataset`` for this layout."""
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFolderDataset)
class ImageWebDatasetFastTests(unittest.TestCase):
    """Placeholder for image web-dataset tests; defines no test methods yet."""

    # TODO(aryan): setup a dummy dataset
class VideoWebDatasetFastTests(unittest.TestCase):
    """Checks ``VideoWebDataset`` against the ``finetrainers/dummy-squish-wds`` dataset.

    NOTE(review): the identifier is presumably resolved remotely, so these
    tests likely need network access -- confirm.
    """

    def setUp(self):
        """Open the dummy dataset as a finite dataset."""
        self.num_data_files = 15
        self.dataset = VideoWebDataset("finetrainers/dummy-squish-wds", infinite=False)

    def test_getitem(self):
        """The first three items expose a caption and a 121-long video of (3, 720, 1280) entries."""
        for position, sample in enumerate(self.dataset):
            # Only the first three items are inspected.
            if position >= 3:
                break
            for key in ("caption", "video"):
                self.assertIn(key, sample)
            video = sample["video"]
            self.assertTrue(torch.is_tensor(video))
            self.assertEqual(len(video), 121)
            self.assertEqual(video[0].shape, (3, 720, 1280))

    def test_initialize_dataset(self):
        """``initialize_dataset`` picks ``VideoWebDataset`` for this identifier."""
        self.assertIsInstance(
            initialize_dataset("finetrainers/dummy-squish-wds", "video", infinite=False),
            VideoWebDataset,
        )
class DatasetUtilsFastTests(unittest.TestCase):
    """Tests for ``find_files`` with the default depth and with explicit depths."""

    @staticmethod
    def _make_txt_file(directory):
        """Create an empty ``.txt`` file in ``directory`` and return its path.

        The handle is closed before returning so no open file descriptor
        outlives the test; ``delete=False`` keeps the file on disk.
        """
        with tempfile.NamedTemporaryFile(dir=directory, suffix=".txt", delete=False) as handle:
            return handle.name

    def test_find_files_depth_0(self):
        """With the default depth, all three files directly under the root are found."""
        with tempfile.TemporaryDirectory() as tmpdir:
            file1 = self._make_txt_file(tmpdir)
            file2 = self._make_txt_file(tmpdir)
            file3 = self._make_txt_file(tmpdir)

            files = find_files(tmpdir, "*.txt")

            self.assertEqual(len(files), 3)
            self.assertIn(file1, files)
            self.assertIn(file2, files)
            self.assertIn(file3, files)

    def test_find_files_depth_n(self):
        """``depth=1`` finds only the first-level file; ``depth=2`` finds both."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # mkdtemp is used for the nested folders: the outer context
            # manager removes the whole tree, so nothing is left for a
            # second TemporaryDirectory object to clean up.
            dir1 = tempfile.mkdtemp(dir=tmpdir)
            dir2 = tempfile.mkdtemp(dir=dir1)
            file1 = self._make_txt_file(dir1)
            file2 = self._make_txt_file(dir2)

            files = find_files(tmpdir, "*.txt", depth=1)
            self.assertEqual(len(files), 1)
            self.assertIn(file1, files)
            self.assertNotIn(file2, files)

            files = find_files(tmpdir, "*.txt", depth=2)
            self.assertEqual(len(files), 2)
            self.assertIn(file1, files)
            self.assertIn(file2, files)
            # Directories themselves must never be reported.
            self.assertNotIn(dir1, files)
            self.assertNotIn(dir2, files)
class ValidationDatasetFastTests(unittest.TestCase):
    """Tests ``ValidationDataset`` on a ``metadata.csv`` that references generated images."""

    # Number of image rows written to the metadata file.
    num_data_files = 3

    def setUp(self):
        """Write blank 64x64 JPEGs plus a CSV listing them, then open the dataset."""
        self.tmpdir = tempfile.TemporaryDirectory()
        root = pathlib.Path(self.tmpdir.name)
        metadata_filename = root / "metadata.csv"

        # Explicit encoding keeps the file independent of the platform default.
        with open(metadata_filename, "w", encoding="utf-8") as f:
            f.write("caption,image_path,video_path\n")
            for i in range(self.num_data_files):
                Image.new("RGB", (64, 64)).save((root / f"{i}.jpg").as_posix())
                # The video_path column is left empty.
                f.write(f"test caption,{self.tmpdir.name}/{i}.jpg,\n")

        self.dataset = ValidationDataset(metadata_filename.as_posix())

    def tearDown(self):
        """Delete the temporary directory created by ``setUp``."""
        self.tmpdir.cleanup()

    def test_getitem(self):
        """Every row comes back with its path and a loaded 64x64 PIL image."""
        num_items = 0
        for i, data in enumerate(self.dataset):
            self.assertEqual(data["image_path"], f"{self.tmpdir.name}/{i}.jpg")
            self.assertIsInstance(data["image"], Image.Image)
            self.assertEqual(data["image"].size, (64, 64))
            num_items += 1
        # Without this check an empty dataset would pass the loop vacuously.
        self.assertEqual(num_items, self.num_data_files)