Spaces:
Runtime error
Runtime error
| import json | |
| from asyncio import Event | |
| import ray | |
| from mtranslate.core import translate | |
| from ray.actor import ActorHandle | |
| from tqdm import tqdm | |
| ray.init() | |
| from typing import Tuple | |
| # Back on the local node, once you launch your remote Ray tasks, call | |
| # `print_until_done`, which will feed everything back into a `tqdm` counter. | |
# Must be a Ray actor: `ProgressBar.__init__` creates it with
# `ProgressBarActor.remote()`, which only exists on `@ray.remote` classes.
@ray.remote
class ProgressBarActor:
    """Actor that accumulates progress reports sent by remote tasks.

    Tasks call `update` to report finished items; a consumer awaits
    `wait_for_update` to learn how many items were completed since it
    last asked, plus the running total.
    """

    counter: int  # total number of items reported so far
    delta: int  # items reported since the last `wait_for_update` returned
    event: Event  # set by `update`, cleared by `wait_for_update`

    def __init__(self) -> None:
        self.counter = 0
        self.delta = 0
        self.event = Event()

    def update(self, num_items_completed: int) -> None:
        """Record that `num_items_completed` more items have finished.

        Adds the amount to both the running total and the pending delta,
        then sets the event so a pending `wait_for_update` can return.
        """
        self.counter += num_items_completed
        self.delta += num_items_completed
        self.event.set()

    async def wait_for_update(self) -> Tuple[int, int]:
        """Wait until somebody calls `update`, then report progress.

        Returns:
            A tuple `(delta, counter)`: the number of items completed since
            the previous call to `wait_for_update`, and the total number of
            completed items.
        """
        await self.event.wait()
        # Re-arm the event and reset the delta so the next call only
        # reports progress made after this point.
        self.event.clear()
        saved_delta = self.delta
        self.delta = 0
        return saved_delta, self.counter

    def get_counter(self) -> int:
        """Return the total number of completed items."""
        return self.counter
class ProgressBar:
    """Local-side handle that renders a `tqdm` bar from a `ProgressBarActor`.

    Create one, pass `actor` to the remote tasks, then call
    `print_until_done` to display progress until `total` items are reported.
    """

    progress_actor: ActorHandle
    total: int
    description: str
    pbar: tqdm

    def __init__(self, total: int, description: str = ""):
        """Create the backing actor and remember the expected item count.

        Args:
            total: Number of items after which the bar is considered done.
            description: Label passed to `tqdm` as `desc`.
        """
        # Ray actors don't seem to play nice with mypy, generating
        # a spurious warning for the following line,
        # which we need to suppress. The code is fine.
        self.progress_actor = ProgressBarActor.remote()  # type: ignore
        self.total = total
        self.description = description

    # Exposed as a property because callers read it as an attribute
    # (`pb.actor`, `self.actor.wait_for_update`), never call it.
    @property
    def actor(self) -> ActorHandle:
        """Returns a reference to the remote `ProgressBarActor`.

        When you complete tasks, call `update` on the actor.
        """
        return self.progress_actor

    def print_until_done(self) -> None:
        """Blocking call.

        Do this after starting a series of remote Ray tasks, to which you've
        passed the actor handle. Each of them calls `update` on the actor.
        When the reported count reaches `total`, this method returns.
        If `total` is zero or negative it returns immediately, since no
        update would ever arrive to wake the wait.
        """
        pbar = tqdm(desc=self.description, total=self.total)
        try:
            counter = 0
            while counter < self.total:
                delta, counter = ray.get(self.actor.wait_for_update.remote())
                pbar.update(delta)
        finally:
            # Close the bar even if fetching an update raises.
            pbar.close()
# Load the mapping whose values are the answer strings to translate.
# JSON text is UTF-8; pass the encoding explicitly so decoding does not
# depend on the platform's locale default.
with open("answer_reverse_mapping.json", encoding="utf-8") as f:
    answer_reverse_mapping = json.load(f)
# Declared as a Ray task because the driver invokes it with
# `translate_answer.remote(...)`.
@ray.remote
def translate_answer(value, pba):
    """Translate `value` from English into French, Spanish and German.

    Args:
        value: The text to translate.
        pba: Handle of the progress actor; its `update` method is invoked
            remotely with 1 once all translations for `value` are done.

    Returns:
        A dict mapping each language code ("fr", "es", "de") to the
        translation of `value` returned by `translate`.
    """
    translations = {
        lang: translate(value, lang, "en") for lang in ["fr", "es", "de"]
    }
    # One progress tick per answer, matching the bar's total (one per value).
    pba.update.remote(1)
    return translations
# Submit one remote translation task per answer; every task receives the
# progress actor so it can report completion.
pb = ProgressBar(len(answer_reverse_mapping))
actor = pb.actor
translation_dicts = [
    translate_answer.remote(value, actor)
    for value in answer_reverse_mapping.values()
]

# Show the progress bar until every task has reported in.
pb.print_until_done()

# Pair each answer with its resolved translations and write the result out.
translation_dict = {
    answer: translations
    for answer, translations in zip(
        answer_reverse_mapping.values(), ray.get(translation_dicts)
    )
}
with open("translation_dict.json", "w") as f:
    json.dump(translation_dict, f)