Spaces:
Runtime error
Runtime error
| import glob | |
| import os | |
| import pathlib | |
| import random | |
| import re | |
| from typing import Iterator | |
| from PIL import UnidentifiedImageError | |
| from imgutils.data import load_image | |
| from .base import RootDataSource | |
| from ..model import ImageItem | |
class LocalSource(RootDataSource):
    """Data source that yields image items loaded from a local directory.

    Files that PIL cannot identify as images are skipped silently.
    """

    def __init__(self, directory: str, recursive: bool = True, shuffle: bool = False):
        """
        :param directory: Directory to read files from.
        :param recursive: When true, walk ``directory`` and all of its
            subdirectories; otherwise only look at its direct children.
        :param shuffle: When true, visit the files in random order.
        """
        self.directory = directory
        self.recursive = recursive
        self.shuffle = shuffle

    def _iter_files(self):
        """Yield ``(path, group_name)`` pairs for every candidate file.

        ``group_name`` is the containing directory path with each run of
        non-word characters/underscores collapsed to a single ``_`` and
        leading/trailing underscores removed.
        """
        if self.recursive:
            for directory, _, files in os.walk(self.directory):
                group_name = re.sub(r'[\W_]+', '_', directory).strip('_')
                for file in files:
                    yield os.path.join(directory, file), group_name
        else:
            group_name = re.sub(r'[\W_]+', '_', self.directory).strip('_')
            for file in os.listdir(self.directory):
                path = os.path.join(self.directory, file)
                # os.listdir() also returns subdirectories; in non-recursive
                # mode they must be skipped rather than opened as images.
                if not os.path.isfile(path):
                    continue
                yield path, group_name

    def _actual_iter_files(self):
        """Yield the pairs from ``_iter_files``, shuffled when requested.

        The full listing is materialized first so it can be shuffled in place.
        """
        lst = list(self._iter_files())
        if self.shuffle:
            random.shuffle(lst)
        yield from lst

    def _iter(self) -> Iterator[ImageItem]:
        """Yield an ``ImageItem`` for every loadable image file.

        Uses the item's own metadata when ``ImageItem.load_from_image``
        provides any; otherwise falls back to a dict with the absolute
        ``path``, the ``group_id`` and the ``filename``.
        """
        # Go through _actual_iter_files() so that the ``shuffle`` option is
        # honoured; iterating _iter_files() directly would ignore it.
        for file, group_name in self._actual_iter_files():
            try:
                origin_item = ImageItem.load_from_image(file)
                # Presumably forces the pixel data to be read now, so that an
                # unreadable file fails here and can be skipped.
                origin_item.image.load()
            except UnidentifiedImageError:
                continue

            meta = origin_item.meta or {
                'path': os.path.abspath(file),
                'group_id': group_name,
                'filename': os.path.basename(file),
            }
            yield ImageItem(origin_item.image, meta)
class LocalTISource(RootDataSource):
    """Data source that yields images paired with tags from sidecar text files.

    For every image ``<name>.<ext>`` directly inside ``directory``, a file
    ``<name>.txt`` in the same directory (if present) is read as a
    comma-separated tag list.
    """

    def __init__(self, directory: str):
        """
        :param directory: Directory holding the images and their ``.txt`` files.
        """
        self.directory = directory

    def _iter(self) -> Iterator[ImageItem]:
        """Yield an ``ImageItem`` for every loadable image in the directory.

        Each item's metadata holds the absolute ``path``, the ``group_id``
        (sanitized directory name), the ``filename`` and a ``tags`` dict
        mapping every tag to ``1.0`` (empty when there is no ``.txt`` file).
        Subdirectories and files PIL cannot identify are skipped.
        """
        group_name = re.sub(r'[\W_]+', '_', self.directory).strip('_')
        # Escape the directory so glob metacharacters in its name
        # ('[', ']', '*', '?') are matched literally.
        pattern = os.path.join(glob.escape(self.directory), '*')
        for f in glob.glob(pattern):
            if not os.path.isfile(f):
                continue

            try:
                image = load_image(f)
            except UnidentifiedImageError:
                # Also covers the .txt sidecar files themselves.
                continue

            id_ = os.path.splitext(os.path.basename(f))[0]
            txt_file = os.path.join(self.directory, f'{id_}.txt')
            if os.path.exists(txt_file):
                full_text = pathlib.Path(txt_file).read_text(encoding='utf-8')
                # Trim every tag (including a trailing newline on the last
                # one) and drop empty entries from blank files or stray commas.
                words = (word.strip() for word in full_text.split(','))
                tags = {word: 1.0 for word in words if word}
            else:
                tags = {}

            meta = {
                'path': os.path.abspath(f),
                'group_id': group_name,
                'filename': os.path.basename(f),
                'tags': tags,
            }
            yield ImageItem(image, meta)