Upload splitters.py with huggingface_hub
Browse files- splitters.py +18 -3
splitters.py
CHANGED
|
@@ -70,6 +70,7 @@ class SeparateSplit(Splitter):
|
|
| 70 |
from_split: str
|
| 71 |
to_split_names: List[str]
|
| 72 |
to_split_sizes: List[int]
|
|
|
|
| 73 |
|
| 74 |
def verify(self):
|
| 75 |
assert (
|
|
@@ -82,13 +83,14 @@ class SeparateSplit(Splitter):
|
|
| 82 |
mapping = {
|
| 83 |
key: {key: [(None, None)]}
|
| 84 |
for key in multi_stream.keys()
|
| 85 |
-
if key != self.from_split
|
| 86 |
}
|
| 87 |
so_far = 0
|
| 88 |
for name, size in itertools.zip_longest(
|
| 89 |
self.to_split_names, self.to_split_sizes
|
| 90 |
):
|
| 91 |
-
|
|
|
|
| 92 |
if size:
|
| 93 |
so_far += size
|
| 94 |
generators = slice_streams(multi_stream, mapping)
|
|
@@ -131,6 +133,14 @@ class Sampler(Artifact):
|
|
| 131 |
) -> List[Dict[str, object]]:
|
| 132 |
pass
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
class RandomSampler(Sampler):
|
| 136 |
def sample(
|
|
@@ -172,6 +182,7 @@ class DiverseLabelsSampler(Sampler):
|
|
| 172 |
|
| 173 |
choices: str = "choices"
|
| 174 |
labels: str = "labels"
|
|
|
|
| 175 |
|
| 176 |
def prepare(self):
|
| 177 |
super().prepare()
|
|
@@ -207,6 +218,8 @@ class DiverseLabelsSampler(Sampler):
|
|
| 207 |
labels = {}
|
| 208 |
for examplar in examplars_pool:
|
| 209 |
label_repr = self.examplar_repr(examplar)
|
|
|
|
|
|
|
| 210 |
if label_repr not in labels:
|
| 211 |
labels[label_repr] = []
|
| 212 |
labels[label_repr].append(examplar)
|
|
@@ -269,7 +282,9 @@ class SpreadSplit(InstanceOperatorWithMultiStreamAccess):
|
|
| 269 |
self.local_cache = list(multi_stream[self.source_stream])
|
| 270 |
|
| 271 |
source_stream = self.local_cache
|
| 272 |
-
|
|
|
|
|
|
|
| 273 |
sampled_instances = self.sampler.sample(source_stream)
|
| 274 |
instance[self.target_field] = sampled_instances
|
| 275 |
return instance
|
|
|
|
| 70 |
from_split: str
|
| 71 |
to_split_names: List[str]
|
| 72 |
to_split_sizes: List[int]
|
| 73 |
+
remove_targets_from_source_split: bool = True
|
| 74 |
|
| 75 |
def verify(self):
|
| 76 |
assert (
|
|
|
|
| 83 |
mapping = {
|
| 84 |
key: {key: [(None, None)]}
|
| 85 |
for key in multi_stream.keys()
|
| 86 |
+
if not self.remove_targets_from_source_split or key != self.from_split
|
| 87 |
}
|
| 88 |
so_far = 0
|
| 89 |
for name, size in itertools.zip_longest(
|
| 90 |
self.to_split_names, self.to_split_sizes
|
| 91 |
):
|
| 92 |
+
if self.remove_targets_from_source_split or name != self.from_split:
|
| 93 |
+
mapping[name] = {self.from_split: [(so_far, size)]}
|
| 94 |
if size:
|
| 95 |
so_far += size
|
| 96 |
generators = slice_streams(multi_stream, mapping)
|
|
|
|
| 133 |
) -> List[Dict[str, object]]:
|
| 134 |
pass
|
| 135 |
|
| 136 |
+
def filter_source_by_instance(
|
| 137 |
+
self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
|
| 138 |
+
) -> List[Dict[str, object]]:
|
| 139 |
+
if "inputs" not in instance:
|
| 140 |
+
raise ValueError(f"'inputs' field is missing from '{instance}'.")
|
| 141 |
+
|
| 142 |
+
return list(filter(lambda x: x["inputs"] != instance["inputs"], instances_pool))
|
| 143 |
+
|
| 144 |
|
| 145 |
class RandomSampler(Sampler):
|
| 146 |
def sample(
|
|
|
|
| 182 |
|
| 183 |
choices: str = "choices"
|
| 184 |
labels: str = "labels"
|
| 185 |
+
include_empty_label: bool = True
|
| 186 |
|
| 187 |
def prepare(self):
|
| 188 |
super().prepare()
|
|
|
|
| 218 |
labels = {}
|
| 219 |
for examplar in examplars_pool:
|
| 220 |
label_repr = self.examplar_repr(examplar)
|
| 221 |
+
if label_repr == "[]" and not self.include_empty_label:
|
| 222 |
+
continue
|
| 223 |
if label_repr not in labels:
|
| 224 |
labels[label_repr] = []
|
| 225 |
labels[label_repr].append(examplar)
|
|
|
|
| 282 |
self.local_cache = list(multi_stream[self.source_stream])
|
| 283 |
|
| 284 |
source_stream = self.local_cache
|
| 285 |
+
source_stream = self.sampler.filter_source_by_instance(
|
| 286 |
+
source_stream, instance
|
| 287 |
+
)
|
| 288 |
sampled_instances = self.sampler.sample(source_stream)
|
| 289 |
instance[self.target_field] = sampled_instances
|
| 290 |
return instance
|