File size: 2,982 Bytes
0f1c5d2
a5e716b
665916b
 
 
 
 
 
 
 
0f1c5d2
665916b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d45d850
a5e716b
665916b
d45d850
0f1c5d2
d45d850
 
 
 
 
0f1c5d2
 
 
d45d850
665916b
 
 
 
a5e716b
 
665916b
 
 
a5e716b
665916b
 
a5e716b
665916b
 
 
 
a5e716b
 
 
 
665916b
 
a5e716b
665916b
 
 
 
 
a5e716b
 
665916b
a5e716b
 
665916b
a5e716b
1ff8cc3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from transformers import T5Tokenizer
from typing import Dict, List, Optional, Union
import os
import logging

logger = logging.getLogger(__name__)

class Byt5LangTokenizer(T5Tokenizer):
    """
    Кастомный токенайзер для ByT5 моделей с поддержкой распознавания таблиц.
    Используется для модели vikp/surya_table
    """

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        sp_model_kwargs=None,
        **kwargs
    ):
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=sp_model_kwargs,
            **kwargs
        )

        # Создаем byte_decoder — важно для ByT5
        self.byte_decoder = {i: bytes([i]) for i in range(256)}

        # Добавляем специальные токены
        self.special_tokens = {
            eos_token: self.convert_token_to_id(eos_token),
            unk_token: self.convert_token_to_id(unk_token),
            pad_token: self.convert_token_to_id(pad_token),
        }

        # Реализуем отсутствующие атрибуты
        self.special_tokens_encoder = self.special_tokens
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

    @property
    def vocab_size(self):
        return 256 + self.num_special_tokens

    def get_vocab(self) -> Dict[str, int]:
        vocab = {chr(i): i for i in range(256)}
        vocab.update(self.special_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[Union[int, str]]:
        return list(text.encode("utf-8"))

    def _convert_token_to_id(self, token: Union[str, int]) -> int:
        if isinstance(token, str):
            if token in self.special_tokens_encoder:
                return self.special_tokens_encoder[token]
            else:
                try:
                    return ord(token)
                except TypeError:
                    return token
        return token

    def _convert_id_to_token(self, index: int) -> Union[str, int]:
        if index in self.special_tokens_decoder:
            return self.special_tokens_decoder[index]
        else:
            return chr(index)

    def convert_tokens_to_string(self, tokens: List[Union[str, int]]) -> str:
        decoded = b""
        for token in tokens:
            if isinstance(token, int):
                decoded += bytes([token])
            else:
                decoded += token.encode("utf-8")
        return decoded.decode("utf-8", errors="replace")