Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
from dataclasses import dataclass, field | |
from functools import partial | |
import pyarrow as pa | |
import uroman | |
from stopes.modules.partitioned_data_mapper import BatchMapper | |
from stopes.utils.arrow_utils import apply_on_nested_array, apply_over_groups | |
from align_utils import get_uroman_tokens | |
from text_normalization import text_normalize | |
class LangColumnConfig: | |
column: str | |
lang_value: str | None = None | |
lang_column: str | None = None | |
@dataclass
class TextRomanizerConfig:
    """Configuration for ``TextRomanizer``.

    By default the ``"text"`` column is processed with the fixed language
    ``"en"``.
    """

    column: LangColumnConfig = field(
        default_factory=lambda: LangColumnConfig(column="text", lang_value="en")
    )

    def __post_init__(self):
        # Fail fast: without a fixed language or a language column there is no
        # language to pass to normalization/romanization. An explicit raise
        # (instead of ``assert``) keeps the check active under ``python -O``.
        if self.column.lang_value is None and self.column.lang_column is None:
            raise ValueError(
                "TextRomanizerConfig.column must set either `lang_value` "
                "or `lang_column`"
            )
class TextRomanizer(BatchMapper):
    """Batch mapper that normalizes a text column and romanizes it with uroman.

    For the configured column, each processed table gains two ``large_string``
    columns: ``<column>_normalized`` (results of ``text_normalize``) and
    ``<column>_tokens`` (results of ``get_uroman_tokens`` on the normalized
    text).
    """

    def __init__(self, config: TextRomanizerConfig):
        super().__init__(config)
        # One romanizer instance, shared by every batch this mapper handles.
        self.uroman = uroman.Uroman()

    def _apply_on_unique_lang_table(
        self, table: pa.Table, config: LangColumnConfig
    ) -> pa.Table:
        """Add normalized and token columns for ``config.column`` to a one-language table.

        The language is read from ``config.lang_column`` when that is set. In
        that case the table must contain exactly one distinct value in that
        column. Otherwise ``config.lang_value`` is used.
        """
        lang = config.lang_value
        if config.lang_column:
            lang_values = table[config.lang_column]
            assert (
                len(lang_values.unique()) == 1
            ), "this method should be called only for unique lang values"
            lang = lang_values[0].as_py()

        try:
            source = table[config.column]
        except KeyError:
            # Struct sub-fields are not top-level columns; after `flatten()`
            # they are reachable as `{column_name}.{struct_field_name}`.
            source = table.flatten()[config.column]

        per_column = partial(self._apply_on_simple_column, lang_value=lang)
        normalized, romanized = apply_on_nested_array(per_column, source)

        return table.append_column(
            f"{config.column}_normalized", normalized
        ).append_column(f"{config.column}_tokens", romanized)

    def _apply_on_simple_column(self, col: pa.Array | pa.ChunkedArray, lang_value: str):
        """Return ``(normalized, tokens)`` as two ``large_string`` arrays.

        Surrounding whitespace is stripped from each value before it is passed
        to ``text_normalize`` with ``lang_value``. The normalized list is then
        given to ``get_uroman_tokens`` along with the shared romanizer.
        Values are assumed to be non-null ``str``, because ``.strip()`` is
        called on each one.
        """
        as_string = pa.large_string()
        cleaned = []
        for raw_text in col.to_pandas().tolist():
            cleaned.append(text_normalize(raw_text.strip(), lang_value))
        token_strings = get_uroman_tokens(cleaned, self.uroman, lang_value)
        normalized_array = pa.array(cleaned, type=as_string)
        tokens_array = pa.array(token_strings, type=as_string)
        return normalized_array, tokens_array

    def __call__(self, table: pa.Table | None) -> pa.Table | None:
        """Process ``table`` one language group at a time; pass ``None`` through unchanged."""
        if table is None:
            return None
        cfg = self.config.column
        per_group = partial(self._apply_on_unique_lang_table, config=cfg)
        # Grouping on the language column gives each call one language.
        # If `cfg.lang_column` is None, the function is applied to the full table.
        return apply_over_groups(table, [cfg.lang_column], per_group)