Spaces:
Runtime error
Runtime error
use gpu if possible
Browse files- app/app.py +3 -2
app/app.py
CHANGED
@@ -16,6 +16,7 @@ mirror_url = "https://news-generator.ai-research.id/"
|
|
16 |
if "MIRROR_URL" in os.environ:
|
17 |
mirror_url = os.environ["MIRROR_URL"]
|
18 |
hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
|
|
|
19 |
|
20 |
MODELS = {
|
21 |
"Indonesian Newspaper - Indonesian GPT-2 Medium": {
|
@@ -63,6 +64,7 @@ def get_generator(model_name: str):
|
|
63 |
st.write(f"Loading the GPT2 model {model_name}, please wait...")
|
64 |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
|
65 |
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, use_auth_token=hf_auth_token)
|
|
|
66 |
model.resize_token_embeddings(len(tokenizer))
|
67 |
return model, tokenizer
|
68 |
|
@@ -82,8 +84,7 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
|
|
82 |
prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
|
83 |
|
84 |
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
|
85 |
-
|
86 |
-
# generated = generated.to(device)
|
87 |
|
88 |
text_generator.eval()
|
89 |
sample_outputs = text_generator.generate(generated,
|
|
|
16 |
if "MIRROR_URL" in os.environ:
|
17 |
mirror_url = os.environ["MIRROR_URL"]
|
18 |
hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
|
19 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
|
21 |
MODELS = {
|
22 |
"Indonesian Newspaper - Indonesian GPT-2 Medium": {
|
|
|
64 |
st.write(f"Loading the GPT2 model {model_name}, please wait...")
|
65 |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
|
66 |
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, use_auth_token=hf_auth_token)
|
67 |
+
model.to(device)
|
68 |
model.resize_token_embeddings(len(tokenizer))
|
69 |
return model, tokenizer
|
70 |
|
|
|
84 |
prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
|
85 |
|
86 |
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
|
87 |
+
generated = generated.to(device)
|
|
|
88 |
|
89 |
text_generator.eval()
|
90 |
sample_outputs = text_generator.generate(generated,
|