TeePoat commited on
Commit
3f2ec31
·
verified ·
1 Parent(s): 00a515d

Update models/transformer/text_generator.py

Browse files
models/transformer/text_generator.py CHANGED
@@ -1,7 +1,7 @@
1
  from transformers import GPT2LMHeadModel
2
  from pathlib import Path
3
  from .utils import modified_tokenizer
4
- from .constants import CHECKPOINT_PATH
5
 
6
 
7
  class TextGenerator:
@@ -12,7 +12,7 @@ class TextGenerator:
12
  """
13
  model_path = Path(data_path) / model_name
14
  self.tokenizer = modified_tokenizer(model_path, None, data_path)
15
- self.model = GPT2LMHeadModel.from_pretrained(str(model_path), device_map="auto")
16
  self.model.eval()
17
 
18
  def generate_text(self,
 
1
  from transformers import GPT2LMHeadModel
2
  from pathlib import Path
3
  from .utils import modified_tokenizer
4
+ from .constants import CHECKPOINT_PATH, HF_TOKEN
5
 
6
 
7
  class TextGenerator:
 
12
  """
13
  model_path = Path(data_path) / model_name
14
  self.tokenizer = modified_tokenizer(model_path, None, data_path)
15
+ self.model = GPT2LMHeadModel.from_pretrained(str(model_path), device_map="auto", token=HF_TOKEN)
16
  self.model.eval()
17
 
18
  def generate_text(self,