Lyti4 commited on
Commit
1ff8cc3
·
verified ·
1 Parent(s): c99d028

Update custom_tokenizers.py

Browse files
Files changed (1) hide show
  1. custom_tokenizers.py +20 -6
custom_tokenizers.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import T5Tokenizer
2
  from typing import Dict, List, Optional, Union
3
  import os
4
  import logging
@@ -23,6 +23,7 @@ class Byt5LangTokenizer(T5Tokenizer):
23
  sp_model_kwargs=None,
24
  **kwargs
25
  ):
 
26
  super().__init__(
27
  vocab_file=vocab_file,
28
  tokenizer_file=tokenizer_file,
@@ -39,15 +40,28 @@ class Byt5LangTokenizer(T5Tokenizer):
39
  self.byte_decoder = {i: bytes([i]) for i in range(256)}
40
 
41
  # Добавляем специальные токены
42
- self.special_tokens = {
43
  eos_token: self.convert_token_to_id(eos_token),
44
  unk_token: self.convert_token_to_id(unk_token),
45
  pad_token: self.convert_token_to_id(pad_token),
46
  }
47
 
48
- # Реализуем отсутствующие атрибуты
49
- self.special_tokens_encoder = self.special_tokens
50
- self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  @property
53
  def vocab_size(self):
@@ -85,4 +99,4 @@ class Byt5LangTokenizer(T5Tokenizer):
85
  decoded += bytes([token])
86
  else:
87
  decoded += token.encode("utf-8")
88
- return decoded.decode("utf-8", errors="replace")
 
1
+ from transformers import T5Tokenizer, PreTrainedTokenizer
2
  from typing import Dict, List, Optional, Union
3
  import os
4
  import logging
 
23
  sp_model_kwargs=None,
24
  **kwargs
25
  ):
26
+ # Вызываем родительский конструктор
27
  super().__init__(
28
  vocab_file=vocab_file,
29
  tokenizer_file=tokenizer_file,
 
40
  self.byte_decoder = {i: bytes([i]) for i in range(256)}
41
 
42
  # Добавляем специальные токены
43
+ special_tokens = {
44
  eos_token: self.convert_token_to_id(eos_token),
45
  unk_token: self.convert_token_to_id(unk_token),
46
  pad_token: self.convert_token_to_id(pad_token),
47
  }
48
 
49
+ # Важно: Проверяем, есть ли уже атрибут special_tokens_encoder
50
+ if not hasattr(self, "special_tokens_encoder"):
51
+ self.special_tokens_encoder = {}
52
+ # Обновляем, а не перезаписываем
53
+ self.special_tokens_encoder.update(special_tokens)
54
+
55
+ # То же для decoder
56
+ if not hasattr(self, "special_tokens_decoder"):
57
+ self.special_tokens_decoder = {}
58
+ self.special_tokens_decoder.update({v: k for k, v in special_tokens.items()})
59
+
60
+ # Добавляем дополнительные атрибуты из родительского класса
61
+ if not hasattr(self, "all_special_tokens"):
62
+ self.all_special_tokens = [eos_token, unk_token, pad_token]
63
+ if not hasattr(self, "all_special_ids"):
64
+ self.all_special_ids = [self.convert_token_to_id(t) for t in self.all_special_tokens]
65
 
66
  @property
67
  def vocab_size(self):
 
99
  decoded += bytes([token])
100
  else:
101
  decoded += token.encode("utf-8")
102
+ return decoded.decode("utf-8", errors="replace")