# mms-transcription/server/uromanizer.py
# EC2 Default User
# Initial Transcription Commit
# 38818c3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from functools import partial
import pyarrow as pa
import uroman
from stopes.modules.partitioned_data_mapper import BatchMapper
from stopes.utils.arrow_utils import apply_on_nested_array, apply_over_groups
from align_utils import get_uroman_tokens
from text_normalization import text_normalize
@dataclass
class LangColumnConfig:
column: str
lang_value: str | None = None
lang_column: str | None = None
@dataclass(kw_only=True)
class TextRomanizerConfig:
    """Configuration for the text romanizer (keyword-only construction).

    Attributes:
        column: The text column to process together with its language
            source. Defaults to the ``"text"`` column with the fixed
            language ``"en"``.

    Raises:
        ValueError: If ``column`` defines neither ``lang_value`` nor
            ``lang_column``, i.e. the language cannot be determined.
    """

    column: LangColumnConfig = field(
        default_factory=lambda: LangColumnConfig(column="text", lang_value="en")
    )

    def __post_init__(self) -> None:
        # Raise explicitly instead of using `assert`: assertions are removed
        # when Python runs with -O, which would let an unusable config through.
        if self.column.lang_value is None and self.column.lang_column is None:
            raise ValueError(
                "TextRomanizerConfig.column must define `lang_value` or `lang_column`"
            )
class TextRomanizer(BatchMapper):
    """Batch mapper that normalizes a text column and computes its uroman tokens.

    For the configured column ``<column>`` two columns are appended to each
    incoming table: ``<column>_normalized`` (result of ``text_normalize``) and
    ``<column>_tokens`` (result of ``get_uroman_tokens``).
    """

    def __init__(self, config: TextRomanizerConfig):
        super().__init__(config)
        # A single romanizer instance is created here and reused by every call.
        self.uroman = uroman.Uroman()

    def _apply_on_unique_lang_table(
        self, table: pa.Table, config: LangColumnConfig
    ) -> pa.Table:
        """Process a table whose rows all share one language.

        Args:
            table: Table to process. When ``config.lang_column`` is set, that
                column must contain exactly one distinct value.
            config: Column / language description.

        Returns:
            ``table`` with ``{config.column}_normalized`` and
            ``{config.column}_tokens`` appended.
        """
        if config.lang_column:
            assert (
                len(table[config.lang_column].unique()) == 1
            ), "this method should be called only for unique lang values"
            lang_value = table[config.lang_column][0].as_py()
        else:
            lang_value = config.lang_value

        try:
            col = table[config.column]
        except KeyError:
            # `table.flatten()` allows accessing fields of a struct directly
            # under the following name: `{column_name}.{struct_field_name}`
            col = table.flatten()[config.column]

        normalized_texts, tokens = apply_on_nested_array(
            partial(self._apply_on_simple_column, lang_value=lang_value),
            col,
        )
        table = table.append_column(f"{config.column}_normalized", normalized_texts)
        table = table.append_column(f"{config.column}_tokens", tokens)
        return table

    def _apply_on_simple_column(
        self, col: pa.Array | pa.ChunkedArray, lang_value: str
    ) -> tuple[pa.Array, pa.Array]:
        """Normalize and romanize every text of a flat column.

        Null entries are kept as nulls in both outputs instead of failing on
        ``None.strip()``; only non-null texts are normalized and tokenized.

        Args:
            col: Column of texts.
            lang_value: Language passed to ``text_normalize`` and
                ``get_uroman_tokens``.

        Returns:
            A pair ``(normalized_texts, tokens)`` of ``large_string`` arrays,
            each with one entry per input row.

        Raises:
            ValueError: If ``get_uroman_tokens`` does not return exactly one
                result per text it was given.
        """
        texts = col.to_pylist()
        # Positions of the rows that actually hold a text.
        present = [idx for idx, text in enumerate(texts) if text is not None]
        normalized_present = [
            text_normalize(texts[idx].strip(), lang_value) for idx in present
        ]
        tokens_present = get_uroman_tokens(normalized_present, self.uroman, lang_value)

        normalized_texts: list[str | None] = [None] * len(texts)
        tokens: list[str | None] = [None] * len(texts)
        # strict=True keeps a length mismatch loud rather than silently
        # leaving rows null.
        for idx, normalized, token in zip(
            present, normalized_present, tokens_present, strict=True
        ):
            normalized_texts[idx] = normalized
            tokens[idx] = token

        return pa.array(normalized_texts, type=pa.large_string()), pa.array(
            tokens, type=pa.large_string()
        )

    def __call__(self, table: pa.Table | None) -> pa.Table | None:
        """Add the normalized-text and token columns to ``table``.

        Args:
            table: Batch to process; ``None`` is passed through unchanged.

        Returns:
            The table with the two extra columns, or ``None`` if ``table``
            was ``None``.
        """
        if table is None:
            return table

        # NOTE(review): `self.config` is presumably set by `BatchMapper.__init__`.
        column_config = self.config.column
        # Note that if `column_config.lang_column` is None the function will be
        # applied on the full table.
        return apply_over_groups(
            table,
            [column_config.lang_column],
            partial(self._apply_on_unique_lang_table, config=column_config),
        )