File size: 3,803 Bytes
1ff8cc3
a5e716b
665916b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ff8cc3
665916b
 
 
 
 
 
 
 
 
 
 
 
d45d850
a5e716b
665916b
d45d850
1ff8cc3
d45d850
 
 
 
 
1ff8cc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d45d850
665916b
 
 
 
a5e716b
 
665916b
 
 
a5e716b
665916b
 
a5e716b
665916b
 
 
 
a5e716b
 
 
 
665916b
 
a5e716b
665916b
 
 
 
 
a5e716b
 
665916b
a5e716b
 
665916b
a5e716b
1ff8cc3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from transformers import T5Tokenizer, PreTrainedTokenizer
from typing import Dict, List, Optional, Union
import os
import logging

logger = logging.getLogger(__name__)

class Byt5LangTokenizer(T5Tokenizer):
    """
    Кастомный токенайзер для ByT5 моделей с поддержкой распознавания таблиц.
    Используется для модели vikp/surya_tablerec
    """

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        sp_model_kwargs=None,
        **kwargs
    ):
        # Вызываем родительский конструктор
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=sp_model_kwargs,
            **kwargs
        )

        # Создаем byte_decoder — важно для ByT5
        self.byte_decoder = {i: bytes([i]) for i in range(256)}

        # Добавляем специальные токены
        special_tokens = {
            eos_token: self.convert_token_to_id(eos_token),
            unk_token: self.convert_token_to_id(unk_token),
            pad_token: self.convert_token_to_id(pad_token),
        }

        # Важно: Проверяем, есть ли уже атрибут special_tokens_encoder
        if not hasattr(self, "special_tokens_encoder"):
            self.special_tokens_encoder = {}
        # Обновляем, а не перезаписываем
        self.special_tokens_encoder.update(special_tokens)

        # То же для decoder
        if not hasattr(self, "special_tokens_decoder"):
            self.special_tokens_decoder = {}
        self.special_tokens_decoder.update({v: k for k, v in special_tokens.items()})

        # Добавляем дополнительные атрибуты из родительского класса
        if not hasattr(self, "all_special_tokens"):
            self.all_special_tokens = [eos_token, unk_token, pad_token]
        if not hasattr(self, "all_special_ids"):
            self.all_special_ids = [self.convert_token_to_id(t) for t in self.all_special_tokens]

    @property
    def vocab_size(self):
        return 256 + self.num_special_tokens

    def get_vocab(self) -> Dict[str, int]:
        vocab = {chr(i): i for i in range(256)}
        vocab.update(self.special_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[Union[int, str]]:
        return list(text.encode("utf-8"))

    def _convert_token_to_id(self, token: Union[str, int]) -> int:
        if isinstance(token, str):
            if token in self.special_tokens_encoder:
                return self.special_tokens_encoder[token]
            else:
                try:
                    return ord(token)
                except TypeError:
                    return token
        return token

    def _convert_id_to_token(self, index: int) -> Union[str, int]:
        if index in self.special_tokens_decoder:
            return self.special_tokens_decoder[index]
        else:
            return chr(index)

    def convert_tokens_to_string(self, tokens: List[Union[str, int]]) -> str:
        decoded = b""
        for token in tokens:
            if isinstance(token, int):
                decoded += bytes([token])
            else:
                decoded += token.encode("utf-8")
        return decoded.decode("utf-8", errors="replace")