Spaces:
Sleeping
Sleeping
File size: 3,803 Bytes
1ff8cc3 a5e716b 665916b 1ff8cc3 665916b d45d850 a5e716b 665916b d45d850 1ff8cc3 d45d850 1ff8cc3 d45d850 665916b a5e716b 665916b a5e716b 665916b a5e716b 665916b a5e716b 665916b a5e716b 665916b a5e716b 665916b a5e716b 665916b a5e716b 1ff8cc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from transformers import T5Tokenizer, PreTrainedTokenizer
from typing import Dict, List, Optional, Union
import os
import logging
logger = logging.getLogger(__name__)
class Byt5LangTokenizer(T5Tokenizer):
"""
Кастомный токенайзер для ByT5 моделей с поддержкой распознавания таблиц.
Используется для модели vikp/surya_tablerec
"""
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
eos_token="</s>",
unk_token="<unk>",
pad_token="<pad>",
extra_ids=100,
additional_special_tokens=None,
sp_model_kwargs=None,
**kwargs
):
# Вызываем родительский конструктор
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
extra_ids=extra_ids,
additional_special_tokens=additional_special_tokens,
sp_model_kwargs=sp_model_kwargs,
**kwargs
)
# Создаем byte_decoder — важно для ByT5
self.byte_decoder = {i: bytes([i]) for i in range(256)}
# Добавляем специальные токены
special_tokens = {
eos_token: self.convert_token_to_id(eos_token),
unk_token: self.convert_token_to_id(unk_token),
pad_token: self.convert_token_to_id(pad_token),
}
# Важно: Проверяем, есть ли уже атрибут special_tokens_encoder
if not hasattr(self, "special_tokens_encoder"):
self.special_tokens_encoder = {}
# Обновляем, а не перезаписываем
self.special_tokens_encoder.update(special_tokens)
# То же для decoder
if not hasattr(self, "special_tokens_decoder"):
self.special_tokens_decoder = {}
self.special_tokens_decoder.update({v: k for k, v in special_tokens.items()})
# Добавляем дополнительные атрибуты из родительского класса
if not hasattr(self, "all_special_tokens"):
self.all_special_tokens = [eos_token, unk_token, pad_token]
if not hasattr(self, "all_special_ids"):
self.all_special_ids = [self.convert_token_to_id(t) for t in self.all_special_tokens]
@property
def vocab_size(self):
return 256 + self.num_special_tokens
def get_vocab(self) -> Dict[str, int]:
vocab = {chr(i): i for i in range(256)}
vocab.update(self.special_tokens_encoder)
return vocab
def _tokenize(self, text: str) -> List[Union[int, str]]:
return list(text.encode("utf-8"))
def _convert_token_to_id(self, token: Union[str, int]) -> int:
if isinstance(token, str):
if token in self.special_tokens_encoder:
return self.special_tokens_encoder[token]
else:
try:
return ord(token)
except TypeError:
return token
return token
def _convert_id_to_token(self, index: int) -> Union[str, int]:
if index in self.special_tokens_decoder:
return self.special_tokens_decoder[index]
else:
return chr(index)
def convert_tokens_to_string(self, tokens: List[Union[str, int]]) -> str:
decoded = b""
for token in tokens:
if isinstance(token, int):
decoded += bytes([token])
else:
decoded += token.encode("utf-8")
return decoded.decode("utf-8", errors="replace")
|