File size: 397 Bytes
89e4bd3
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import List, Union

from torch import Tensor
from transformers import T5Tokenizer


class LlmJpT5Tokenizer(T5Tokenizer):
    """T5 tokenizer whose ``decode`` delegates directly to the SentencePiece model.

    Only ``decode`` is overridden; everything else is inherited from
    ``T5Tokenizer``.
    """

    def decode(
        tokenizer: T5Tokenizer,
        token_ids: Union[int, List[int], Tensor, None],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Union[bool, None] = None,
        **kwargs,
    ) -> Union[str, None]:
        """Decode token ids into text using ``sp_model.decode``.

        Args:
            tokenizer: The bound tokenizer instance (the conventional ``self``;
                the original parameter name is kept for compatibility).
            token_ids: Ids to decode. May be ``None``, a single int, a
                sequence of ints, or any object exposing ``tolist()`` such as
                a tensor.
            skip_special_tokens: Accepted for signature compatibility with
                the inherited ``decode``/``batch_decode``; not used here.
            clean_up_tokenization_spaces: Accepted for signature
                compatibility; not used here.
            **kwargs: Accepted for signature compatibility; ignored.

        Returns:
            ``None`` when ``token_ids`` is ``None``, ``""`` when it is empty,
            otherwise whatever ``sp_model.decode`` returns for the ids.
        """
        if token_ids is None:
            return None

        # The SentencePiece decoder works on plain Python ints/lists, so
        # tensor-like inputs are converted first.
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()

        # A single id (e.g. from a 0-d tensor) is treated as a one-item list.
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        else:
            token_ids = list(token_ids)

        if not token_ids:
            return ""
        return tokenizer.sp_model.decode(token_ids)