cahya committed on
Commit
1afd82a
·
1 Parent(s): 08000e6

use gpu if possible

Browse files
Files changed (1) hide show
  1. app/app.py +3 -2
app/app.py CHANGED
@@ -16,6 +16,7 @@ mirror_url = "https://news-generator.ai-research.id/"
16
  if "MIRROR_URL" in os.environ:
17
  mirror_url = os.environ["MIRROR_URL"]
18
  hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
 
19
 
20
  MODELS = {
21
  "Indonesian Newspaper - Indonesian GPT-2 Medium": {
@@ -63,6 +64,7 @@ def get_generator(model_name: str):
63
  st.write(f"Loading the GPT2 model {model_name}, please wait...")
64
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
65
  model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, use_auth_token=hf_auth_token)
 
66
  model.resize_token_embeddings(len(tokenizer))
67
  return model, tokenizer
68
 
@@ -82,8 +84,7 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
82
  prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
83
 
84
  generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
85
- # device = torch.device("cuda")
86
- # generated = generated.to(device)
87
 
88
  text_generator.eval()
89
  sample_outputs = text_generator.generate(generated,
 
16
  if "MIRROR_URL" in os.environ:
17
  mirror_url = os.environ["MIRROR_URL"]
18
  hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
  MODELS = {
22
  "Indonesian Newspaper - Indonesian GPT-2 Medium": {
 
64
  st.write(f"Loading the GPT2 model {model_name}, please wait...")
65
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
66
  model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, use_auth_token=hf_auth_token)
67
+ model.to(device)
68
  model.resize_token_embeddings(len(tokenizer))
69
  return model, tokenizer
70
 
 
84
  prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
85
 
86
  generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
87
+ generated = generated.to(device)
 
88
 
89
  text_generator.eval()
90
  sample_outputs = text_generator.generate(generated,