import json

import torch


class UstaTokenizer:
  def __init__(self, vocab_file):
    # load the sub-word vocabulary (token -> id) and build the inverse
    # map (id -> token) used by tokenize and decode
    with open(vocab_file, "r") as f:
      self.vocab = json.load(f)
    self.reverse_vocab = {v: k for k, v in self.vocab.items()}

  def encode_batch(self, texts, context_length):
    sentences_tokens = []
    for text in texts:
      tokens = self.encode(text).tolist()
      if len(tokens) > context_length:
        # truncate sequences longer than the context length
        tokens = tokens[:context_length]
      else:
        # pad shorter sequences with <pad> up to the context length
        tokens = tokens + [self.vocab["<pad>"]] * (context_length - len(tokens))

      sentences_tokens.append(tokens)

    return torch.tensor(sentences_tokens)
   
  def encode(self, text):
    tokens = []

    for word in text.split():
      i = 0
      # greedy longest-match sub-word lookup,
      # e.g. "states" -> "state" (id 4) + "s" (id 58)
      while i < len(word):
        found_match = False
        for j in range(len(word), i, -1):
          sub_word = word[i:j]
          if sub_word in self.vocab:
            tokens.append(self.vocab[sub_word])
            i = j
            found_match = True
            break
        if not found_match:
          # no vocabulary entry matched; emit <unk> and advance one character
          tokens.append(self.vocab["<unk>"])
          i += 1
      tokens.append(self.vocab[" "])

    # drop the trailing space token unless the text actually ends with a space
    if tokens and not text.endswith(" "):
      tokens.pop()
    return torch.tensor(tokens)
  
  def tokenize(self, text):
    token_ids = self.encode(text)
    # convert the tensor of ids back to a plain Python list
    token_ids = token_ids.tolist()

    return [self.reverse_vocab[id] for id in token_ids]

  def decode(self, ids):
    # map each id back to its token string and concatenate
    return "".join(self.reverse_vocab[id] for id in ids)
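
# --- Example usage (a minimal sketch, not part of the original file) ---
# Assumes a "vocab.json" file mapping sub-word strings to integer ids that
# includes the "<pad>", "<unk>", and " " entries the class relies on; the
# file name and sample sentences below are placeholders.
if __name__ == "__main__":
  tokenizer = UstaTokenizer("vocab.json")

  # single sentence -> 1-D tensor of token ids
  ids = tokenizer.encode("the united states")
  print(tokenizer.tokenize("the united states"))
  print(tokenizer.decode(ids.tolist()))

  # batch of sentences -> (batch, context_length) tensor, truncated or padded
  batch = tokenizer.encode_batch(["the united states", "canada"], context_length=8)
  print(batch.shape)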