from typing import List, Optional
import random

from datasets import load_dataset


class PromptLoader:
    """Load prompts from the ``daspartho/stable-diffusion-prompts`` dataset and sample them.

    The ``prompt`` column of the dataset's ``train`` split is fetched on first
    use and cached on the instance. Sampling uses an instance-local
    ``random.Random`` seeded at construction, so a given seed yields a
    reproducible sequence of samples.
    """

    def __init__(self, seed: int = 42) -> None:
        """Create a loader.

        Args:
            seed: Seed for the instance-local random generator used by
                ``load_data`` when sampling.
        """
        self.randomizer = random.Random(seed)
        # Cached prompt column; ``None`` means "not loaded yet" (an empty list
        # is a valid, already-loaded cache and must not trigger a reload).
        self.data: Optional[List[str]] = None

    def _load_data(self) -> None:
        """Fetch the ``prompt`` column of the ``train`` split and cache it."""
        # Materialize with ``list(...)`` so the cache really is a ``List[str]``
        # whatever type the installed ``datasets`` version returns for column
        # access (newer releases return a lazy column object).
        self.data = list(
            load_dataset("daspartho/stable-diffusion-prompts")["train"]["prompt"]
        )

    def load_data(self, size: Optional[int] = None) -> List[str]:
        """Return prompts from the dataset, loading it on the first call.

        Args:
            size: Number of prompts to draw without replacement using the
                instance's seeded generator. ``None`` (the default) returns
                every prompt. ``0`` returns an empty list.

        Returns:
            A new list of prompts. Mutating it does not affect the cache.

        Raises:
            ValueError: If ``size`` is larger than the number of available
                prompts (or negative, via ``random.Random.sample``).
        """
        # Compare against None explicitly: an empty cache is still "loaded".
        if self.data is None:
            self._load_data()

        if size is None:
            # Hand out a copy so callers cannot mutate the shared cache.
            return list(self.data)

        if size > len(self.data):
            raise ValueError("Not enough samples available!")
        return self.randomizer.sample(self.data, size)