File size: 3,080 Bytes
38818c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from functools import partial

import pyarrow as pa
import uroman
from stopes.modules.partitioned_data_mapper import BatchMapper
from stopes.utils.arrow_utils import apply_on_nested_array, apply_over_groups

from align_utils import get_uroman_tokens
from text_normalization import text_normalize


@dataclass
class LangColumnConfig:
    """Names the text column to process and where its language comes from.

    The language is either a fixed value (``lang_value``) or read from another
    column of the table (``lang_column``).
    """

    # Name of the column holding the text. May also be a flattened struct field
    # name of the form `{column_name}.{struct_field_name}`.
    column: str
    # Language identifier applied to the whole table when `lang_column` is not set.
    lang_value: str | None = None
    # Name of the column holding the language; when set it is used instead of
    # `lang_value`.
    lang_column: str | None = None


@dataclass(kw_only=True)
class TextRomanizerConfig:
    """Configuration for ``TextRomanizer``.

    Holds a single ``LangColumnConfig`` describing which column to romanize and
    how its language is determined.
    """

    # By default the "text" column is processed with the fixed language "en".
    column: LangColumnConfig = field(
        default_factory=lambda: LangColumnConfig(column="text", lang_value="en")
    )

    def __post_init__(self):
        """Check that the column config provides a language source.

        Raises:
            AssertionError: if both ``column.lang_value`` and
                ``column.lang_column`` are ``None``.
        """
        # Kept as an `assert` so the exception type seen by callers is unchanged;
        # the message makes a misconfiguration diagnosable.
        assert (
            self.column.lang_value is not None or self.column.lang_column is not None
        ), "TextRomanizerConfig.column must define either `lang_value` or `lang_column`"


class TextRomanizer(BatchMapper):
    """Batch mapper that adds normalized and romanized versions of a text column.

    For the configured column ``<column>``, two columns are appended to the
    table: ``<column>_normalized`` (output of ``text_normalize``) and
    ``<column>_tokens`` (output of ``get_uroman_tokens``).
    """

    def __init__(self, config: TextRomanizerConfig):
        """Pass ``config`` to ``BatchMapper`` and create the ``Uroman`` instance.

        NOTE(review): ``self.config`` read in ``__call__`` is presumably set by
        ``BatchMapper.__init__`` -- confirm.
        """
        super().__init__(config)
        self.uroman = uroman.Uroman()

    def _apply_on_unique_lang_table(
        self, table: pa.Table, config: LangColumnConfig
    ) -> pa.Table:
        """Normalize and romanize ``config.column`` of a single-language table.

        Args:
            table: table whose rows all share one language.
            config: column configuration; if ``lang_column`` is set, the
                language is read from the first row of that column, otherwise
                ``lang_value`` is used.

        Returns:
            ``table`` with ``{config.column}_normalized`` and
            ``{config.column}_tokens`` columns appended.

        Raises:
            AssertionError: if ``lang_column`` is set and does not contain
                exactly one unique value.
        """
        if config.lang_column:
            assert (
                len(table[config.lang_column].unique()) == 1
            ), "this method should be called only for unique lang values"
            lang_value = table[config.lang_column][0].as_py()
        else:
            lang_value = config.lang_value

        try:
            col = table[config.column]
        except KeyError:
            # `table.flatten()` allows to access fields from struct directly
            # with the following name: `{column_name}.{struct_field_name}`
            col = table.flatten()[config.column]

        normalized_texts, tokens = apply_on_nested_array(
            partial(self._apply_on_simple_column, lang_value=lang_value),
            col,
        )

        table = table.append_column(f"{config.column}_normalized", normalized_texts)
        table = table.append_column(f"{config.column}_tokens", tokens)
        return table

    def _apply_on_simple_column(
        self, col: pa.Array | pa.ChunkedArray, lang_value: str
    ) -> tuple[pa.Array, pa.Array]:
        """Normalize and romanize a flat array of texts.

        Null entries are kept as nulls in both outputs instead of failing on
        ``None.strip()``.

        Args:
            col: flat array of texts (may contain nulls).
            lang_value: language passed to ``text_normalize`` and
                ``get_uroman_tokens``.

        Returns:
            A pair ``(normalized_texts, tokens)`` of ``large_string`` arrays,
            each with the same length as ``col``.

        Raises:
            ValueError: if ``get_uroman_tokens`` does not return exactly one
                result per input text.
        """
        texts = col.to_pylist()
        # Only non-null texts are sent through normalization / romanization;
        # their positions are remembered so results can be scattered back.
        present = [idx for idx, text in enumerate(texts) if text is not None]

        normalized_texts: list[str | None] = [None] * len(texts)
        tokens: list[str | None] = [None] * len(texts)

        if present:
            normalized_present = [
                text_normalize(texts[idx].strip(), lang_value) for idx in present
            ]
            tokens_present = get_uroman_tokens(
                normalized_present, self.uroman, lang_value
            )
            # `strict=True` surfaces a length mismatch here rather than as a
            # misaligned column later on.
            for idx, normalized, token in zip(
                present, normalized_present, tokens_present, strict=True
            ):
                normalized_texts[idx] = normalized
                tokens[idx] = token

        return pa.array(normalized_texts, type=pa.large_string()), pa.array(
            tokens, type=pa.large_string()
        )

    def __call__(self, table: pa.Table | None) -> pa.Table | None:
        """Apply the romanizer to ``table``, one language group at a time.

        Args:
            table: input table, or ``None``.

        Returns:
            ``None`` if ``table`` is ``None``; otherwise the result of
            ``apply_over_groups`` over ``_apply_on_unique_lang_table``.
        """
        if table is None:
            return table

        column_config = self.config.column
        table = apply_over_groups(
            table,
            [
                column_config.lang_column
            ],  # note that if `column_config.lang_column` is None function will be applied on the full table
            partial(self._apply_on_unique_lang_table, config=column_config),
        )
        return table