from transformers import T5Tokenizer
from typing import Dict, List, Optional, Union
import os
import logging
logger = logging.getLogger(__name__)


class Byt5LangTokenizer(T5Tokenizer):
    """
    Custom tokenizer for ByT5 models with support for table recognition.
    Used with the vikp/surya_table model.
    """

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        sp_model_kwargs=None,
        **kwargs
    ):
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=sp_model_kwargs,
            **kwargs
        )
        # Create the byte_decoder mapping; this matters for ByT5-style byte decoding
        self.byte_decoder = {i: bytes([i]) for i in range(256)}
        # Register the special tokens and their ids
        self.special_tokens = {
            eos_token: self.convert_token_to_id(eos_token),
            unk_token: self.convert_token_to_id(unk_token),
            pad_token: self.convert_token_to_id(pad_token),
        }
        # Provide the encoder/decoder attributes that are otherwise missing
        self.special_tokens_encoder = self.special_tokens
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}

    @property
    def vocab_size(self):
        # 256 byte values plus the registered special tokens
        return 256 + len(self.special_tokens)

    def get_vocab(self) -> Dict[str, int]:
        # Byte values 0-255, then the special tokens
        vocab = {chr(i): i for i in range(256)}
        vocab.update(self.special_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[Union[int, str]]:
        # Tokenize by splitting the text into its raw UTF-8 byte values
        return list(text.encode("utf-8"))

    def _convert_token_to_id(self, token: Union[str, int]) -> int:
        if isinstance(token, str):
            if token in self.special_tokens_encoder:
                return self.special_tokens_encoder[token]
            try:
                # Single characters map to their code point / byte value
                return ord(token)
            except TypeError:
                return token
        return token

    def _convert_id_to_token(self, index: int) -> Union[str, int]:
        if index in self.special_tokens_decoder:
            return self.special_tokens_decoder[index]
        return chr(index)

    def convert_tokens_to_string(self, tokens: List[Union[str, int]]) -> str:
        # Reassemble the byte values and decode them back into a UTF-8 string
        decoded = b""
        for token in tokens:
            if isinstance(token, int):
                decoded += bytes([token])
            else:
                decoded += token.encode("utf-8")
        return decoded.decode("utf-8", errors="replace")
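

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes the
# vikp/surya_table repository ships the tokenizer files that
# T5Tokenizer.from_pretrained() expects; the round trip below only exercises
# the byte-level methods defined above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = Byt5LangTokenizer.from_pretrained("vikp/surya_table")
    sample = "Row 1 | Row 2"
    byte_tokens = tokenizer._tokenize(sample)          # raw UTF-8 byte values
    restored = tokenizer.convert_tokens_to_string(byte_tokens)
    assert restored == sample                          # lossless byte round trip
    print(byte_tokens[:8], restored)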