# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from functools import partial

import pyarrow as pa
import uroman
from stopes.modules.partitioned_data_mapper import BatchMapper
from stopes.utils.arrow_utils import apply_on_nested_array, apply_over_groups

from align_utils import get_uroman_tokens
from text_normalization import text_normalize


@dataclass
class LangColumnConfig:
    """Describe which text column to process and how to find its language.

    The language is taken from `lang_column` when that is set (truthy),
    otherwise from the constant `lang_value`.
    """

    # Name of the column holding the text. A struct field can be addressed as
    # `{column_name}.{struct_field_name}` (resolved through `table.flatten()`).
    column: str
    # Constant language applied to every row; used only when `lang_column`
    # is not set.
    lang_value: str | None = None
    # Name of the column holding the language of each row; takes precedence
    # over `lang_value`.
    lang_column: str | None = None


@dataclass(kw_only=True)
class TextRomanizerConfig:
    """Configuration of `TextRomanizer`.

    By default the `text` column is processed with the constant language "en".
    """

    column: LangColumnConfig = field(
        default_factory=lambda: LangColumnConfig(column="text", lang_value="en")
    )

    def __post_init__(self):
        """Check that the language can be resolved from a value or a column."""
        assert (
            self.column.lang_value is not None or self.column.lang_column is not None
        ), "either `lang_value` or `lang_column` must be set"


class TextRomanizer(BatchMapper):
    """Batch mapper adding normalized text and uroman tokens to a table.

    For the configured column, two columns are appended to the table:
    `{column}_normalized` and `{column}_tokens`.
    """

    def __init__(self, config: TextRomanizerConfig):
        """Forward `config` to the parent mapper and create the uroman instance."""
        super().__init__(config)
        # Created once per mapper so that every batch reuses the same instance.
        self.uroman = uroman.Uroman()

    def _apply_on_unique_lang_table(
        self, table: pa.Table, config: LangColumnConfig
    ) -> pa.Table:
        """Process a table whose rows all share one language.

        Args:
            table: Input table. When `config.lang_column` is set, that column
                must contain exactly one distinct value.
            config: Column and language settings.

        Returns:
            The table with `{config.column}_normalized` and
            `{config.column}_tokens` columns appended.

        Raises:
            AssertionError: If `config.lang_column` is set and holds more or
                fewer than one distinct value.
        """
        if config.lang_column:
            assert (
                len(table[config.lang_column].unique()) == 1
            ), "this method should be called only for unique lang values"
            lang_value = table[config.lang_column][0].as_py()
        else:
            lang_value = config.lang_value

        try:
            col = table[config.column]
        except KeyError:
            # `table.flatten()` allows to access fields from struct directly
            # with the following name: `{column_name}.{struct_field_name}`
            col = table.flatten()[config.column]

        normalized_texts, tokens = apply_on_nested_array(
            partial(self._apply_on_simple_column, lang_value=lang_value),
            col,
        )
        table = table.append_column(f"{config.column}_normalized", normalized_texts)
        table = table.append_column(f"{config.column}_tokens", tokens)
        return table

    def _apply_on_simple_column(
        self, col: pa.Array | pa.ChunkedArray, lang_value: str
    ) -> tuple[pa.Array, pa.Array]:
        """Normalize and romanize the texts of one column.

        Args:
            col: Column of texts.
            lang_value: Language passed to `text_normalize` and
                `get_uroman_tokens`.

        Returns:
            A pair `(normalized_texts, tokens)` of `large_string` arrays.
        """
        texts = col.to_pandas().tolist()
        # NOTE(review): every entry is expected to be a `str`; a null entry
        # would presumably come back as `None` and fail on `.strip()` - confirm
        # that callers never pass nulls.
        normalized_texts = [text_normalize(text.strip(), lang_value) for text in texts]
        tokens = get_uroman_tokens(normalized_texts, self.uroman, lang_value)
        return pa.array(normalized_texts, type=pa.large_string()), pa.array(
            tokens, type=pa.large_string()
        )

    def __call__(self, table: pa.Table | None) -> pa.Table | None:
        """Apply the romanizer on `table`, one language group at a time.

        Args:
            table: Input batch; `None` is passed through unchanged.

        Returns:
            The table returned by `apply_over_groups`, or `None` when `table`
            is `None`.
        """
        if table is None:
            return table

        column_config = self.config.column
        table = apply_over_groups(
            table,
            [
                column_config.lang_column
            ],  # note that if `column_config.lang_column` is None function will be applied on the full table
            partial(self._apply_on_unique_lang_table, config=column_config),
        )
        return table