Upload fusion.py with huggingface_hub
Browse files
fusion.py
CHANGED
|
@@ -4,7 +4,7 @@ from typing import Generator, List, Optional
|
|
| 4 |
|
| 5 |
from .dataclass import NonPositionalField
|
| 6 |
from .operator import SourceOperator, StreamSource
|
| 7 |
-
from .random_utils import
|
| 8 |
from .stream import MultiStream, Stream
|
| 9 |
|
| 10 |
|
|
@@ -89,10 +89,13 @@ class WeightedFusion(BaseFusion):
|
|
| 89 |
weights = copy.deepcopy(self.weights)
|
| 90 |
iterators = [iter(origin()[split]) for origin in self.origins]
|
| 91 |
total_examples = 0
|
|
|
|
| 92 |
while (
|
| 93 |
self.max_total_examples is None or total_examples <= self.max_total_examples
|
| 94 |
) and len(iterators) > 0:
|
| 95 |
-
iterator =
|
|
|
|
|
|
|
| 96 |
try:
|
| 97 |
yield next(iterator)
|
| 98 |
total_examples += 1
|
|
|
|
| 4 |
|
| 5 |
from .dataclass import NonPositionalField
|
| 6 |
from .operator import SourceOperator, StreamSource
|
| 7 |
+
from .random_utils import new_random_generator
|
| 8 |
from .stream import MultiStream, Stream
|
| 9 |
|
| 10 |
|
|
|
|
| 89 |
weights = copy.deepcopy(self.weights)
|
| 90 |
iterators = [iter(origin()[split]) for origin in self.origins]
|
| 91 |
total_examples = 0
|
| 92 |
+
random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
|
| 93 |
while (
|
| 94 |
self.max_total_examples is None or total_examples <= self.max_total_examples
|
| 95 |
) and len(iterators) > 0:
|
| 96 |
+
iterator = random_generator.choices(population=iterators, weights=weights)[
|
| 97 |
+
0
|
| 98 |
+
]
|
| 99 |
try:
|
| 100 |
yield next(iterator)
|
| 101 |
total_examples += 1
|