File size: 2,984 Bytes
a5e716b
 
665916b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d45d850
a5e716b
665916b
d45d850
 
 
 
 
 
 
 
 
 
 
665916b
 
 
 
a5e716b
 
665916b
 
 
a5e716b
665916b
 
a5e716b
665916b
 
 
 
a5e716b
 
 
 
665916b
 
a5e716b
665916b
 
 
 
 
a5e716b
 
665916b
a5e716b
 
665916b
a5e716b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from transformers import T5Tokenizer
from typing import Dict, List, Optional, Union
import os
import logging

logger = logging.getLogger(__name__)

class Byt5LangTokenizer(T5Tokenizer):
    """
    Кастомный токенайзер для ByT5 моделей с поддержкой распознавания таблиц.
    Используется для модели vikp/surya_tablerec
    """

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        sp_model_kwargs=None,
        **kwargs
    ):
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=sp_model_kwargs,
            **kwargs
        )

        # Создаем byte_decoder — важно для ByT5
        self.byte_decoder = {i: bytes([i]) for i in range(256)}

        # Добавляем специальные токены
        self.special_tokens = {
            eos_token: self.convert_token_to_id(eos_token),
            unk_token: self.convert_token_to_id(unk_token),
            pad_token: self.convert_token_to_id(pad_token),
        }

        # Реализуем отсутствующие атрибуты
        self.special_tokens_encoder = self.special_tokens
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

    @property
    def vocab_size(self):
        return 256 + self.num_special_tokens

    def get_vocab(self) -> Dict[str, int]:
        vocab = {chr(i): i for i in range(256)}
        vocab.update(self.special_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[Union[int, str]]:
        return list(text.encode("utf-8"))

    def _convert_token_to_id(self, token: Union[str, int]) -> int:
        if isinstance(token, str):
            if token in self.special_tokens_encoder:
                return self.special_tokens_encoder[token]
            else:
                try:
                    return ord(token)
                except TypeError:
                    return token
        return token

    def _convert_id_to_token(self, index: int) -> Union[str, int]:
        if index in self.special_tokens_decoder:
            return self.special_tokens_decoder[index]
        else:
            return chr(index)

    def convert_tokens_to_string(self, tokens: List[Union[str, int]]) -> str:
        decoded = b""
        for token in tokens:
            if isinstance(token, int):
                decoded += bytes([token])
            else:
                decoded += token.encode("utf-8")
        return decoded.decode("utf-8", errors="replace")