Upload splitters.py with huggingface_hub
Browse files- splitters.py +51 -0
splitters.py
CHANGED
|
@@ -102,6 +102,57 @@ class RandomSampler(Sampler):
|
|
| 102 |
return random.sample(instances_pool, self.sample_size)
|
| 103 |
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
class SpreadSplit(InstanceOperatorWithGlobalAccess):
|
| 106 |
source_stream: str = None
|
| 107 |
target_field: str = None
|
|
|
|
| 102 |
return random.sample(instances_pool, self.sample_size)
|
| 103 |
|
| 104 |
|
| 105 |
+
class DiverseLabelsSampler(Sampler):
|
| 106 |
+
choices: str = "choices"
|
| 107 |
+
|
| 108 |
+
def prepare(self):
|
| 109 |
+
super().prepare()
|
| 110 |
+
self.labels = None
|
| 111 |
+
|
| 112 |
+
def examplar_repr(self, examplar):
|
| 113 |
+
assert (
|
| 114 |
+
"inputs" in examplar and self.choices in examplar["inputs"]
|
| 115 |
+
), f"DiverseLabelsSampler assumes each examplar has {self.choices} field in it input"
|
| 116 |
+
examplar_outputs = next(iter(examplar["outputs"].values()))
|
| 117 |
+
return str([choice for choice in examplar["inputs"][self.choices] if choice in examplar_outputs])
|
| 118 |
+
|
| 119 |
+
def divide_by_repr(self, examplars_pool):
|
| 120 |
+
labels = dict()
|
| 121 |
+
for examplar in examplars_pool:
|
| 122 |
+
label_repr = self.examplar_repr(examplar)
|
| 123 |
+
if label_repr not in labels:
|
| 124 |
+
labels[label_repr] = []
|
| 125 |
+
labels[label_repr].append(examplar)
|
| 126 |
+
return labels
|
| 127 |
+
|
| 128 |
+
def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
|
| 129 |
+
if self.labels is None:
|
| 130 |
+
self.labels = self.divide_by_repr(instances_pool)
|
| 131 |
+
all_labels = list(self.labels.keys())
|
| 132 |
+
random.shuffle(all_labels)
|
| 133 |
+
from collections import Counter
|
| 134 |
+
|
| 135 |
+
total_allocated = 0
|
| 136 |
+
allocations = Counter()
|
| 137 |
+
|
| 138 |
+
while total_allocated < self.sample_size:
|
| 139 |
+
for label in all_labels:
|
| 140 |
+
if total_allocated < self.sample_size:
|
| 141 |
+
if len(self.labels[label]) - allocations[label] > 0:
|
| 142 |
+
allocations[label] += 1
|
| 143 |
+
total_allocated += 1
|
| 144 |
+
else:
|
| 145 |
+
break
|
| 146 |
+
|
| 147 |
+
result = []
|
| 148 |
+
for label, allocation in allocations.items():
|
| 149 |
+
sample = random.sample(self.labels[label], allocation)
|
| 150 |
+
result.extend(sample)
|
| 151 |
+
|
| 152 |
+
random.shuffle(result)
|
| 153 |
+
return result
|
| 154 |
+
|
| 155 |
+
|
| 156 |
class SpreadSplit(InstanceOperatorWithGlobalAccess):
|
| 157 |
source_stream: str = None
|
| 158 |
target_field: str = None
|